In [33]:
using Surrogates
using Plots
using Statistics
using Random
using DataFrames
using Distances
using LinearAlgebra
default()

In [51]:
# Define the objective waterflow function (8 dimensions)
function f(x)
    r_w = x[1]
    r = x[2]
    T_u = x[3]
    H_u = x[4]
    T_l = x[5]
    H_l = x[6]
    L = x[7]
    K_w = x[8]
    log_val = log(r/r_w)
    return (2*pi*T_u*(H_u - H_l))/ ( log_val*(1 + (2*L*T_u/(log_val*r_w^2*K_w)) + T_u/T_l))
end


n = 250
initial_n = 50
d = 8
lb = [0.05,100,63070,990,63.1,700,1120,9855]
ub = [0.15,50000,115600,1110,116,820,1680,12045]
original_x = sample(initial_n,lb,ub,SobolSample())
original_y = f.(original_x)

50-element Vector{Float64}:
  16.614887299231814
  66.71434150808243
 145.01472282611837
  50.379068168663856
  45.72674669008426
 218.59877093618695
  88.33929805974748
  31.64730062042864
  37.60355111410419
 141.98371486579754
 146.85835662067186
  43.9988868435236
  38.084914505299544
   ⋮
  80.35171582280238
  50.983847087634985
  40.289796197934024
  84.18332252527621
 132.0679010627322
  78.6406514896613
  46.946759319980025
 161.65406427158032
  52.79670900960365
  26.784637013946686
  33.81519848960852
 115.332266280138

In [3]:
function splitdf(df, pct)
           @assert 0 <= pct <= 1
           ids = collect(axes(df, 1))
           shuffle!(ids)
           sel = ids .<= nrow(df) .* pct
           return DataFrame(view(df, sel, :)), DataFrame(view(df, .!sel, :))
       end

splitdf (generic function with 1 method)

In [24]:
# Implementing diversity
function calculate_variance(x, models)
    predictions = []
        for model in models
            prediction = model(x)
            append!(predictions, prediction)
        end
    return var(predictions, corrected=false)
end

function diversity_metric(prev_x, new_x, models, lambda = 0.5)
    min_dist = Inf
    variance = calculate_variance(new_x, models)
    for point in prev_x
        new_dist = euclidean(point, new_x)
        if new_dist < min_dist
            min_dist = new_dist
        end
    end
    return (1 - lambda) * sqrt(variance) + lambda * min_dist
end

diversity_metric (generic function with 2 methods)

In [43]:
function calculate_error(point, models, actual, mode="MSE")
    target = actual(point)
    errors = []
    for model in models
        prediction = model(point)
        append!(errors, abs(target - prediction))
    end
    if mode == "MSE"
        return mean(errors.^2)
    end
    if mode == "max"
        return maximum(errors)
    end
end

calculate_error (generic function with 2 methods)

In [52]:
total_samples = 1000
prev_points = copy(original_x)
y = copy(original_y)
sample_space = sample(total_samples, lb, ub, SobolSample())
test, orig_train = splitdf(DataFrame(sample_space), 0.2)
train = copy(orig_train)
error = Inf
while error > 47
    kriging_surrogate = Kriging(prev_points, y, lb, ub, p=[1.9, 1.9, 1.9, 1.9, 1.9, 1.9, 1.9, 1.9])
    my_radial_basis = RadialBasis(prev_points, y, lb, ub)
    x = []
    max_score = 0
    max_index = -1
    for j in 1:size(train)[1]
        point = Tuple(collect(train[j, :]))
        score = diversity_metric(prev_points, point, [my_radial_basis, kriging_surrogate], 0)
        if score > max_score
            max_score = score
            max_index = j
        end
    end
    append!(x, Vector([Tuple(collect(train[max_index, :]))]))
    deleteat!(train, max_index)
    prev_points = vcat(prev_points, x)
    y = f.(prev_points)
    errors = []
    for k in 1:size(test)[1]
        point = Tuple(collect(train[k, :]))
        append!(errors, calculate_error(point, [my_radial_basis, kriging_surrogate], f))
    end
    error = mean(errors)
    println(error)
end

21.90001839931228
110.76926205503958
76.24030733502157
33.463873952809294
53.48154675420044
76.24030733502157
31.63677661277999
58.892582094712736
76.24030733502157
84.3003715275262
103.66791277503484
76.24030733502157
29.012865379464007
67.37574406494974
76.24030733502157
96.47143085214874
127.56418043620806
76.24030733502157
194.00964716978868
70.10424330935535
76.24030733502157
64.04710518146591
51.89204435280499
76.24030733502157
41.19797663815929
78.24265944154365
76.24030733502157
180.48507769724392
75.5812581705872
76.24030733502157
27.616697816610326
90.73944223826379
76.24030733502157
113.889696615875
61.51839559040366
76.24030733502157
52.19804826406512
102.67723716043758
76.24030733502157
85.06300859753586
69.15944115399265
76.24030733502157
95.01598970366304
60.59801928425077
76.24030733502157
128.4636639239418
71.22421905377507
76.24030733502157
40.68082617625494
105.47938699423116
76.24030733502157
32.82724515027502
83.41689908609874
76.24030733502157
75.59215809226716
10

21.90001839931228
111.823563786249
76.35075258712286
33.463873952809294
54.290462125038516
76.35075258712286
31.63677661277999
57.79372082938744
76.35075258712286
84.3003715275262
104.19481862524003
76.35075258712286
29.012865379464007
69.68115393251662
76.35075258712286
96.47143085214874
104.92961801149255
76.35075258712286
194.00964716978868
71.77631034450656
76.35075258712286
64.04710518146591
52.08591536015501
76.35075258712286
41.19797663815929
79.7046469907994
76.35075258712286
180.48507769724392
76.04723803499724
76.35075258712286
27.616697816610326
90.04915067689467
76.35075258712286
113.889696615875
63.48086916731802
76.35075258712286
52.19804826406512
102.55915265459464
76.35075258712286
85.06300859753586
70.47146748030741
76.35075258712286
95.01598970366304
60.946565503876286
76.35075258712286
128.4636639239418
69.97336664354793
76.35075258712286
40.68082617625494
105.47643146560188
76.35075258712286
32.82724515027502
86.24348207996422
76.35075258712286
75.59215809226716
105

76.35075258712286
24.225382503214526
80.7486314915185
76.35075258712286
16.38891030508094
96.27441627161204
76.35075258712286
68.4748015114352
92.47837485158936
76.35075258712286
91.1749252331103
60.20363903040152
76.35075258712286
62.12718367157293
68.55425680623648
76.35075258712286
137.00265143949147
44.426473557751365
76.35075258712286
89.41009416579186
62.50779525007988
76.35075258712286
31.469692472193348
119.28298884719288
76.35075258712286
35.28032980298908
71.3183331233347
76.35075258712286
99.09902110199518
47.82423249898261
76.35075258712286
75.08629149879545
80.12091614900305
76.35075258712286
88.63892516034458
54.60153584140943
76.35075258712286
25.93540597180137
124.33728256318068
76.35075258712286
79.4112604025619
114.44750954655547
76.35075258712286
202.42642209152234
81.06813827942358
76.35075258712286
39.45687280323495
62.55072415081963
76.35075258712286
82.09325905369145
56.824808545714745
76.35075258712286
98.70718164274872
116.2742486524221
76.35075258712286
32.052

75.26097494092872
52.42810595297942
35.81274621831358
75.26097494092872
105.78296117670621
110.49112892644507
75.26097494092872
63.85303854738947
78.40167547146689
75.26097494092872
30.944045261598735
80.90936864438481
75.26097494092872
41.97133003638445
77.64045482515712
75.26097494092872
186.3296667002348
71.73594923652854
75.26097494092872
56.74881122299658
112.41146240898752
75.26097494092872
108.49084709562106
70.9191085025019
75.26097494092872
73.71898369643768
50.57077935341022
75.26097494092872
28.600728266124367
68.02742260709681
75.26097494092872
18.20645808686957
38.34561030528698
75.26097494092872
143.05942671536093
102.83002883076006
75.26097494092872
59.21121120977573
56.38016215930429
75.26097494092872
148.9103654389534
111.74944558909897
75.26097494092872
109.13468549265777
77.16579591913796
75.26097494092872
34.72540570398675
60.5860336416132
75.26097494092872
25.593529254928907
66.53131954813489
75.26097494092872
131.6312200888044
58.89345198880983
75.26097494092872
6

74.39374959970476
70.3094361865444
100.76681437260913
74.39374959970476
116.7004901085634
65.38678973073769
74.39374959970476
48.79404246276152
53.23568957781458
74.39374959970476
60.49206553633695
56.81115255379723
74.39374959970476
159.17172939040512
75.29155844377397
74.39374959970476
51.01098310779134
84.35211909072723
74.39374959970476
37.16358464231329
82.28072886623863
74.39374959970476
44.52188765548431
80.6872717469571
74.39374959970476
92.2411522348061
85.05658124783963
74.39374959970476
130.22309210535093
74.85514621577431
74.39374959970476
76.89620769353296
62.37422029981553
74.39374959970476
127.42300045633091
62.71484462432932
74.39374959970476
73.69957102278055
74.10378562503331
74.39374959970476
15.44771268555035
101.91661210149076
74.39374959970476
67.15006909668735
66.47022348554083
74.39374959970476
143.60043340708333
47.799979141265
74.39374959970476
50.247563210349824
71.81098157441238
74.39374959970476
76.61220251926561
63.55750048777213
74.39374959970476
157.7259

74.59111383531116
15.44771268555035
102.05271344651692
74.59111383531116
67.15006909668735
66.1780401363344
74.59111383531116
143.60043340708333
47.63678520053054
74.59111383531116
50.247563210349824
73.14831845536105
74.59111383531116
76.61220251926561
63.20975027706925
74.59111383531116
157.72597982100356
68.00718685431957
74.59111383531116
136.98313998895338
71.98589499541617
74.59111383531116
23.22810151615365
79.78646169770968
74.59111383531116
32.43783905047222
76.55374434864001
74.59111383531116
131.37913731858663
71.48489578762815
74.59111383531116
143.57460283812884
49.19007414567068
74.59111383531116
43.487234245190905
97.73056661569854
74.59111383531116
46.81595207847135
37.99386097477418
74.59111383531116
95.92299771013309
65.91851068311348
74.59111383531116
16.057716326529388
59.6599577904301
74.59111383531116
27.534780854866266
50.88358221454064
74.59111383531116
52.42810595297942
35.81838518750044
74.59111383531116
105.78296117670621
90.34449273074642
74.59111383531116
6

LoadError: InterruptException:

In [53]:
println(size(prev_points))

(55,)


In [38]:
n_test = 1000
x_test = sample(n_test,lb,ub,GoldenSample());
y_true = f.(x_test);
my_rad = RadialBasis(prev_points, y, lb, ub)
my_krig = Kriging(prev_points, y, lb, ub, p=[1.9, 1.9, 1.9, 1.9, 1.9, 1.9, 1.9, 1.9])
y_rad = my_rad.(x_test)
y_krig = my_krig.(x_test);

In [39]:
mse_rad = norm(y_true - y_rad,2)/n_test
mse_krig = norm(y_true - y_krig,2)/n_test
print("MSE RadialBasis: $mse_rad    ")
print("MSE Kriging: $mse_krig    ")

MSE RadialBasis: 1.5698430282200377    MSE Kriging: 1.4261068005903885    