In [9]:
from hypersurface_tf import *
from generate_h import *
from biholoNN import *
import tensorflow as tf
import numpy as np

### Prepare the dataset:

In [10]:
z0, z1, z2, z3, z4 = sp.symbols('z0, z1, z2, z3, z4')
Z = [z0,z1,z2,z3,z4]
f = z0**5 + z1**5 + z2**5 + z3**5 + z4**5 + 0.5*z0*z1*z2*z3*z4
np.random.seed(123)
HS = Hypersurface(Z, f, 10000)
np.random.seed(124)
HS_test = Hypersurface(Z, f, 10000)

In [11]:
train_set = generate_dataset(HS)
test_set = generate_dataset(HS_test)









In [12]:
train_set = train_set.shuffle(500000).batch(1000)
test_set = test_set.shuffle(50000).batch(1000)

### Build the model:

In [21]:
class KahlerPotential(tf.keras.Model):

    def __init__(self):
        super(KahlerPotential, self).__init__()
        self.biholomorphic = Biholomorphic()
        self.layer1 = tf.keras.layers.Dense(100, activation=tf.square, use_bias=False)
        self.layer2 = tf.keras.layers.Dense(100, activation=tf.square, use_bias=False)
        #self.layer3 = tf.keras.layers.Dense(100, activation=tf.square, use_bias=False)
        #self.layer_4 = ComplexDense(50, 10, activation=tf.square)
        #self.layer_3 = ComplexDense(10, 15, activation=tf.square)
        #self.g = ComplexG(70)
        
    def call(self, inputs):
        x = self.biholomorphic(inputs)
        x = self.layer1(x)
        x = self.layer2(x)
        #x = self.layer3(x)
        #x = self.layer_4(x)
        #x = self.g(x)
        #x = tf.linalg.diag_part(tf.matmul(x, x, adjoint_b=True))
        #x = tf.math.log(x)
        x = tf.reduce_sum(x, 1)
        x = tf.math.log(x)
        return x

In [22]:
model = KahlerPotential()

In [23]:
@tf.function
def volume_form(x, Omega_Omegabar, mass, restriction):

    kahler_metric = complex_hessian(tf.math.real(model(x)), x)
    volume_form = tf.math.real(tf.linalg.det(tf.matmul(restriction, tf.matmul(kahler_metric, restriction, adjoint_b=True))))
    weights = mass / tf.reduce_sum(mass)
    factor = tf.reduce_sum(weights * volume_form / Omega_Omegabar)
    #factor = tf.constant(35.1774, dtype=tf.complex64)
    return volume_form / factor

In [24]:
#optimizer = tf.keras.optimizers.SGD(learning_rate=1e-1)
#learning_rate = 1
n_layers = 3
epochs = 6000

alpha = 0.001
beta1 = 0.9
beta2 = 0.999
eta = 1e-8

m = [0]*n_layers
v = [0]*n_layers
t = 0

for epoch in range(epochs):
    for step, (points, Omega_Omegabar, mass, restriction) in enumerate(train_set):
        with tf.GradientTape() as tape:
            
            omega = volume_form(points, Omega_Omegabar, mass, restriction)
            loss = weighted_MAPE(Omega_Omegabar, omega, mass)  
            grads = tape.gradient(loss, model.trainable_weights)
            #print(grads)

        t = t + 1
        i = 0
        for weight, grad in zip(model.trainable_weights, grads):
            
            m[i] = beta1 * m[i] + (1 - beta1) * grad
            v[i] = beta2 * v[i] + (1 - beta2) * grad**2
  
            alpha_t = alpha * np.sqrt(1 - beta2**t) / (1 - beta1**t)
            theta = alpha_t * m[i] / (tf.sqrt(v[i]) + eta)
            #print(theta)
            
            i = i + 1
            
            weight.assign_sub(theta)
    
        if step % 500 == 0:
            print("step %d: loss = %.4f" % (step, loss))
    
    test_loss = 0
    test_loss_old = 100
    
    for step, (points, Omega_Omegabar, mass, restriction) in enumerate(test_set):
        omega = volume_form(points, Omega_Omegabar, mass, restriction)
        test_loss += weighted_MAPE(Omega_Omegabar, omega, mass)
   
    test_loss = tf.math.real(test_loss).numpy()/(step+1)
    print("test_loss:", test_loss)
    
    # This part doesn't work right now
    if test_loss > test_loss_old:
        break
    test_loss_old = test_loss

step 0: loss = 0.6070
test_loss: 0.09932825088500977
step 0: loss = 0.0978
test_loss: 0.06426638126373291
step 0: loss = 0.0657
test_loss: 0.05879887104034424
step 0: loss = 0.0564
test_loss: 0.05617036342620849
step 0: loss = 0.0546
test_loss: 0.05415566444396973
step 0: loss = 0.0538
test_loss: 0.05303669452667236
step 0: loss = 0.0513
test_loss: 0.05133878231048584
step 0: loss = 0.0507
test_loss: 0.04886788845062256
step 0: loss = 0.0483
test_loss: 0.04734538555145264
step 0: loss = 0.0462
test_loss: 0.04464642524719238
step 0: loss = 0.0418
test_loss: 0.042682137489318844
step 0: loss = 0.0413
test_loss: 0.04032407760620117
step 0: loss = 0.0377
test_loss: 0.037876405715942384
step 0: loss = 0.0357
test_loss: 0.03602229118347168
step 0: loss = 0.0340
test_loss: 0.034733164310455325
step 0: loss = 0.0316
test_loss: 0.033231563568115234
step 0: loss = 0.0326
test_loss: 0.03134039878845215
step 0: loss = 0.0303
test_loss: 0.030648682117462158
step 0: loss = 0.0297
test_loss: 0.029255

step 0: loss = 0.0143
test_loss: 0.014604041576385498
step 0: loss = 0.0135
test_loss: 0.01464722990989685
step 0: loss = 0.0133
test_loss: 0.014336119890213012
step 0: loss = 0.0132
test_loss: 0.01455539584159851
step 0: loss = 0.0131
test_loss: 0.014279202222824097
step 0: loss = 0.0138
test_loss: 0.014589097499847412
step 0: loss = 0.0136
test_loss: 0.014656957387924194
step 0: loss = 0.0142
test_loss: 0.014251896142959596
step 0: loss = 0.0127
test_loss: 0.01470661759376526
step 0: loss = 0.0146
test_loss: 0.014461288452148438
step 0: loss = 0.0135
test_loss: 0.014452280998229981
step 0: loss = 0.0131
test_loss: 0.01420001745223999
step 0: loss = 0.0130
test_loss: 0.01477858543395996
step 0: loss = 0.0135
test_loss: 0.014307210445404053
step 0: loss = 0.0141
test_loss: 0.014119254350662231
step 0: loss = 0.0137
test_loss: 0.01390238881111145
step 0: loss = 0.0125
test_loss: 0.014655587673187255
step 0: loss = 0.0133
test_loss: 0.014095629453659058
step 0: loss = 0.0139
test_loss: 0

step 0: loss = 0.0118
test_loss: 0.012855548858642579
step 0: loss = 0.0122
test_loss: 0.012555787563323975
step 0: loss = 0.0113
test_loss: 0.012621641159057617
step 0: loss = 0.0114
test_loss: 0.013076064586639404
step 0: loss = 0.0127
test_loss: 0.013016525506973266
step 0: loss = 0.0118
test_loss: 0.013001410961151124
step 0: loss = 0.0132
test_loss: 0.012733206748962403
step 0: loss = 0.0122
test_loss: 0.012698205709457398
step 0: loss = 0.0113
test_loss: 0.012925715446472167
step 0: loss = 0.0120
test_loss: 0.012793762683868408
step 0: loss = 0.0117
test_loss: 0.012715648412704467
step 0: loss = 0.0120
test_loss: 0.0126580810546875
step 0: loss = 0.0115
test_loss: 0.012709435224533081
step 0: loss = 0.0117
test_loss: 0.012709603309631348
step 0: loss = 0.0112
test_loss: 0.01283471941947937
step 0: loss = 0.0125
test_loss: 0.013153269290924072
step 0: loss = 0.0129
test_loss: 0.012661217451095582
step 0: loss = 0.0115
test_loss: 0.012532508373260498
step 0: loss = 0.0115
test_loss

test_loss: 0.01232213020324707
step 0: loss = 0.0115
test_loss: 0.012053852081298827
step 0: loss = 0.0111
test_loss: 0.01229580283164978
step 0: loss = 0.0116
test_loss: 0.01229584813117981
step 0: loss = 0.0112
test_loss: 0.012347056865692138
step 0: loss = 0.0113
test_loss: 0.012108452320098877
step 0: loss = 0.0116
test_loss: 0.012210074663162231
step 0: loss = 0.0110
test_loss: 0.011870735883712768
step 0: loss = 0.0114
test_loss: 0.012217087745666504
step 0: loss = 0.0116
test_loss: 0.01262924551963806
step 0: loss = 0.0111
test_loss: 0.012287839651107788
step 0: loss = 0.0113
test_loss: 0.01221219778060913
step 0: loss = 0.0109
test_loss: 0.012246286869049073
step 0: loss = 0.0113
test_loss: 0.012103725671768189
step 0: loss = 0.0111
test_loss: 0.012257717847824097
step 0: loss = 0.0117
test_loss: 0.01205466628074646
step 0: loss = 0.0108
test_loss: 0.011964136362075805
step 0: loss = 0.0110
test_loss: 0.01202862024307251
step 0: loss = 0.0116
test_loss: 0.011995714902877808
ste

step 0: loss = 0.0110
test_loss: 0.012300717830657958
step 0: loss = 0.0111
test_loss: 0.011715128421783447
step 0: loss = 0.0107
test_loss: 0.011873233318328857
step 0: loss = 0.0109
test_loss: 0.011850458383560181
step 0: loss = 0.0110
test_loss: 0.011875067949295043
step 0: loss = 0.0107
test_loss: 0.011894387006759644
step 0: loss = 0.0109
test_loss: 0.011773816347122192
step 0: loss = 0.0108
test_loss: 0.01178048610687256
step 0: loss = 0.0112
test_loss: 0.011778953075408936
step 0: loss = 0.0103
test_loss: 0.011644768714904784
step 0: loss = 0.0113
test_loss: 0.01179140329360962
step 0: loss = 0.0114
test_loss: 0.0116320538520813
step 0: loss = 0.0109
test_loss: 0.011862061023712157
step 0: loss = 0.0110
test_loss: 0.01182376742362976
step 0: loss = 0.0109
test_loss: 0.01168299674987793
step 0: loss = 0.0103
test_loss: 0.011917111873626708
step 0: loss = 0.0112
test_loss: 0.011730700731277466
step 0: loss = 0.0105
test_loss: 0.011661229133605957
step 0: loss = 0.0112
test_loss: 0

test_loss: 0.011271333694458008
step 0: loss = 0.0099
test_loss: 0.0115838623046875
step 0: loss = 0.0103
test_loss: 0.011475112438201904
step 0: loss = 0.0108
test_loss: 0.011384128332138062
step 0: loss = 0.0103
test_loss: 0.01123572826385498
step 0: loss = 0.0103
test_loss: 0.011474486589431763
step 0: loss = 0.0105
test_loss: 0.011337437629699708
step 0: loss = 0.0108
test_loss: 0.011461586952209472
step 0: loss = 0.0108
test_loss: 0.011700071096420288
step 0: loss = 0.0108
test_loss: 0.011448899507522583
step 0: loss = 0.0098
test_loss: 0.011663172245025635
step 0: loss = 0.0108
test_loss: 0.011235851049423217
step 0: loss = 0.0105
test_loss: 0.011395037174224854
step 0: loss = 0.0104
test_loss: 0.01101369857788086
step 0: loss = 0.0104
test_loss: 0.011409794092178344
step 0: loss = 0.0097
test_loss: 0.0112982177734375
step 0: loss = 0.0100
test_loss: 0.011384934186935425
step 0: loss = 0.0108
test_loss: 0.01137187957763672
step 0: loss = 0.0101
test_loss: 0.011234095096588135
ste

step 0: loss = 0.0099
test_loss: 0.010840804576873779
step 0: loss = 0.0095
test_loss: 0.010872359275817872
step 0: loss = 0.0098
test_loss: 0.010754210948944092
step 0: loss = 0.0101
test_loss: 0.010916249752044678
step 0: loss = 0.0095
test_loss: 0.010837202072143554
step 0: loss = 0.0099
test_loss: 0.010936692953109742
step 0: loss = 0.0096
test_loss: 0.010625579357147218
step 0: loss = 0.0091
test_loss: 0.010554314851760864
step 0: loss = 0.0098
test_loss: 0.010861893892288208
step 0: loss = 0.0093
test_loss: 0.010809582471847535
step 0: loss = 0.0098
test_loss: 0.010739225149154662
step 0: loss = 0.0093
test_loss: 0.010794166326522827
step 0: loss = 0.0098
test_loss: 0.010977808237075806
step 0: loss = 0.0096
test_loss: 0.010709964036941529
step 0: loss = 0.0096
test_loss: 0.010960996150970459
step 0: loss = 0.0100
test_loss: 0.010933706760406494
step 0: loss = 0.0098
test_loss: 0.010748269557952881
step 0: loss = 0.0096
test_loss: 0.010872316360473634
step 0: loss = 0.0098
test_l

test_loss: 0.010398112535476685
step 0: loss = 0.0097
test_loss: 0.010486541986465455
step 0: loss = 0.0095
test_loss: 0.010274497270584106
step 0: loss = 0.0096
test_loss: 0.010048015117645264
step 0: loss = 0.0091
test_loss: 0.010105626583099365
step 0: loss = 0.0091
test_loss: 0.01005695343017578
step 0: loss = 0.0100
test_loss: 0.010067170858383179
step 0: loss = 0.0092
test_loss: 0.010076462030410767
step 0: loss = 0.0094
test_loss: 0.010148364305496215
step 0: loss = 0.0093
test_loss: 0.010049372911453247
step 0: loss = 0.0093
test_loss: 0.00997318685054779
step 0: loss = 0.0096
test_loss: 0.01014478325843811
step 0: loss = 0.0093
test_loss: 0.010368337631225586
step 0: loss = 0.0096
test_loss: 0.009982883930206299
step 0: loss = 0.0092
test_loss: 0.010176297426223755
step 0: loss = 0.0098
test_loss: 0.01016790270805359
step 0: loss = 0.0095
test_loss: 0.010052508115768433
step 0: loss = 0.0092
test_loss: 0.010000114440917968
step 0: loss = 0.0093
test_loss: 0.010083096027374268


step 0: loss = 0.0084
test_loss: 0.009372406601905824
step 0: loss = 0.0084
test_loss: 0.009231780767440795
step 0: loss = 0.0080
test_loss: 0.009273610115051269
step 0: loss = 0.0084
test_loss: 0.009386776089668274
step 0: loss = 0.0088
test_loss: 0.009411943554878234
step 0: loss = 0.0087
test_loss: 0.009438936114311219
step 0: loss = 0.0084
test_loss: 0.009419693946838378
step 0: loss = 0.0085
test_loss: 0.009491570591926575
step 0: loss = 0.0084
test_loss: 0.009373973608016967
step 0: loss = 0.0087
test_loss: 0.009362724423408509
step 0: loss = 0.0089
test_loss: 0.009395543336868286
step 0: loss = 0.0086
test_loss: 0.009217519760131836
step 0: loss = 0.0082
test_loss: 0.009242109060287475
step 0: loss = 0.0087
test_loss: 0.009395062327384948
step 0: loss = 0.0084
test_loss: 0.00944710373878479
step 0: loss = 0.0084
test_loss: 0.009281663298606873
step 0: loss = 0.0087
test_loss: 0.009522305727005005
step 0: loss = 0.0083
test_loss: 0.009217580556869507
step 0: loss = 0.0084
test_lo

test_loss: 0.00855210304260254
step 0: loss = 0.0078
test_loss: 0.00887513518333435
step 0: loss = 0.0082
test_loss: 0.008714357018470764
step 0: loss = 0.0081
test_loss: 0.008925659656524658
step 0: loss = 0.0085
test_loss: 0.008693567514419555
step 0: loss = 0.0081
test_loss: 0.008551061153411865
step 0: loss = 0.0080
test_loss: 0.008715958595275878
step 0: loss = 0.0081
test_loss: 0.008821208477020264
step 0: loss = 0.0085
test_loss: 0.008538338541984557
step 0: loss = 0.0074
test_loss: 0.008773186206817628
step 0: loss = 0.0082
test_loss: 0.008911653757095336
step 0: loss = 0.0083
test_loss: 0.008738818764686584
step 0: loss = 0.0085
test_loss: 0.008651597499847412
step 0: loss = 0.0079
test_loss: 0.008779150247573853
step 0: loss = 0.0085
test_loss: 0.008676458597183228
step 0: loss = 0.0081
test_loss: 0.008601388335227967
step 0: loss = 0.0075
test_loss: 0.008539427518844605
step 0: loss = 0.0079
test_loss: 0.008708089590072632
step 0: loss = 0.0083
test_loss: 0.00872660279273986

step 0: loss = 0.0078
test_loss: 0.008154905438423156
step 0: loss = 0.0078
test_loss: 0.008083112835884095
step 0: loss = 0.0078
test_loss: 0.008215517401695252
step 0: loss = 0.0076
test_loss: 0.008389523029327392
step 0: loss = 0.0079
test_loss: 0.008347931504249572
step 0: loss = 0.0075
test_loss: 0.008160091042518615
step 0: loss = 0.0072
test_loss: 0.008238874077796936
step 0: loss = 0.0077
test_loss: 0.008208262920379638
step 0: loss = 0.0076
test_loss: 0.008251801729202271
step 0: loss = 0.0076
test_loss: 0.008171581029891968
step 0: loss = 0.0077
test_loss: 0.008287410140037537
step 0: loss = 0.0079
test_loss: 0.008173862099647522
step 0: loss = 0.0071
test_loss: 0.008124948143959046
step 0: loss = 0.0075
test_loss: 0.008393983840942382
step 0: loss = 0.0075
test_loss: 0.008278454542160035
step 0: loss = 0.0078
test_loss: 0.008220037817955017
step 0: loss = 0.0079
test_loss: 0.008222497701644897
step 0: loss = 0.0077
test_loss: 0.008234515190124511
step 0: loss = 0.0076
test_l

test_loss: 0.008087474107742309
step 0: loss = 0.0074
test_loss: 0.007819663286209106
step 0: loss = 0.0070
test_loss: 0.007862188816070557
step 0: loss = 0.0076
test_loss: 0.008120529055595398
step 0: loss = 0.0071
test_loss: 0.008019642233848571
step 0: loss = 0.0072
test_loss: 0.007862632274627685
step 0: loss = 0.0071
test_loss: 0.007867833375930786
step 0: loss = 0.0073
test_loss: 0.00785421371459961
step 0: loss = 0.0073
test_loss: 0.007788245677947998
step 0: loss = 0.0071
test_loss: 0.007971317172050477
step 0: loss = 0.0070
test_loss: 0.00786730945110321
step 0: loss = 0.0074
test_loss: 0.008063468933105469
step 0: loss = 0.0072
test_loss: 0.008124005794525147
step 0: loss = 0.0081
test_loss: 0.008204606771469116
step 0: loss = 0.0073
test_loss: 0.007961530089378357
step 0: loss = 0.0073
test_loss: 0.00812263548374176
step 0: loss = 0.0077
test_loss: 0.008022791743278503
step 0: loss = 0.0075
test_loss: 0.007841119766235352
step 0: loss = 0.0073
test_loss: 0.00806826114654541


step 0: loss = 0.0072
test_loss: 0.008007927536964417
step 0: loss = 0.0073
test_loss: 0.007801474928855896
step 0: loss = 0.0073
test_loss: 0.007953295707702637
step 0: loss = 0.0074
test_loss: 0.007806025147438049
step 0: loss = 0.0068
test_loss: 0.007752736806869507
step 0: loss = 0.0070
test_loss: 0.007765102982521057
step 0: loss = 0.0071
test_loss: 0.007957691550254822
step 0: loss = 0.0072
test_loss: 0.00793093740940094
step 0: loss = 0.0072
test_loss: 0.00795095145702362
step 0: loss = 0.0074
test_loss: 0.00785582721233368
step 0: loss = 0.0073
test_loss: 0.007832999229431153
step 0: loss = 0.0073
test_loss: 0.0077026402950286866
step 0: loss = 0.0070
test_loss: 0.007769097685813904
step 0: loss = 0.0068
test_loss: 0.007862519025802612
step 0: loss = 0.0070
test_loss: 0.007937241196632385
step 0: loss = 0.0071
test_loss: 0.007725223302841187
step 0: loss = 0.0072
test_loss: 0.008031325340270996
step 0: loss = 0.0074
test_loss: 0.007801476716995239
step 0: loss = 0.0071
test_los

step 0: loss = 0.0071
test_loss: 0.007806910872459411
step 0: loss = 0.0070
test_loss: 0.0076975500583648685
step 0: loss = 0.0071
test_loss: 0.0077471107244491575
step 0: loss = 0.0070
test_loss: 0.007868832349777222
step 0: loss = 0.0077
test_loss: 0.007770164608955383
step 0: loss = 0.0075
test_loss: 0.007743797898292541
step 0: loss = 0.0072
test_loss: 0.007756677865982056
step 0: loss = 0.0071
test_loss: 0.007802667617797851
step 0: loss = 0.0074
test_loss: 0.007753002643585205
step 0: loss = 0.0072
test_loss: 0.007950816750526429
step 0: loss = 0.0070
test_loss: 0.007808334231376648
step 0: loss = 0.0071
test_loss: 0.007898271679878235
step 0: loss = 0.0072
test_loss: 0.0077426421642303464
step 0: loss = 0.0070
test_loss: 0.007540010809898376
step 0: loss = 0.0072
test_loss: 0.00792822778224945
step 0: loss = 0.0075
test_loss: 0.00800668478012085
step 0: loss = 0.0075
test_loss: 0.007734202742576599
step 0: loss = 0.0075
test_loss: 0.007737597227096557
step 0: loss = 0.0070
test_

step 0: loss = 0.0068
test_loss: 0.007766430377960205
step 0: loss = 0.0071
test_loss: 0.007768632769584656
step 0: loss = 0.0071
test_loss: 0.007552317976951599
step 0: loss = 0.0068
test_loss: 0.007530996799468994
step 0: loss = 0.0072
test_loss: 0.007628249526023864
step 0: loss = 0.0069
test_loss: 0.007838433980941773
step 0: loss = 0.0077
test_loss: 0.007735054492950439
step 0: loss = 0.0071
test_loss: 0.007927101850509644
step 0: loss = 0.0072
test_loss: 0.007580289840698242
step 0: loss = 0.0071
test_loss: 0.007505001425743103
step 0: loss = 0.0073
test_loss: 0.007402669191360473
step 0: loss = 0.0066
test_loss: 0.007600072622299194
step 0: loss = 0.0070
test_loss: 0.007565281987190246
step 0: loss = 0.0071
test_loss: 0.007725988626480103
step 0: loss = 0.0076
test_loss: 0.007725962400436401
step 0: loss = 0.0069
test_loss: 0.007710299491882324
step 0: loss = 0.0069
test_loss: 0.007649818062782287
step 0: loss = 0.0071
test_loss: 0.007623597979545593
step 0: loss = 0.0069
test_l

test_loss: 0.007718824744224548
step 0: loss = 0.0070
test_loss: 0.007653070092201233
step 0: loss = 0.0072
test_loss: 0.0075219708681106565
step 0: loss = 0.0072
test_loss: 0.007568563222885132
step 0: loss = 0.0067
test_loss: 0.007545222043991089
step 0: loss = 0.0070
test_loss: 0.007550870776176452
step 0: loss = 0.0066
test_loss: 0.007504303455352784
step 0: loss = 0.0069
test_loss: 0.007594735026359558
step 0: loss = 0.0070
test_loss: 0.007570546865463257
step 0: loss = 0.0072
test_loss: 0.00760083794593811
step 0: loss = 0.0074
test_loss: 0.0075711715221405025
step 0: loss = 0.0067
test_loss: 0.00758699893951416
step 0: loss = 0.0069
test_loss: 0.007631179690361023
step 0: loss = 0.0070
test_loss: 0.00750180721282959
step 0: loss = 0.0068
test_loss: 0.007480441927909851
step 0: loss = 0.0067
test_loss: 0.00756780743598938
step 0: loss = 0.0064
test_loss: 0.007712128162384033
step 0: loss = 0.0072
test_loss: 0.007478553056716919
step 0: loss = 0.0069
test_loss: 0.00759211540222167

step 0: loss = 0.0072
test_loss: 0.007456585764884949
step 0: loss = 0.0070
test_loss: 0.007446190118789673
step 0: loss = 0.0072
test_loss: 0.007812107801437378
step 0: loss = 0.0073
test_loss: 0.007437871694564819
step 0: loss = 0.0066
test_loss: 0.00763231635093689
step 0: loss = 0.0074
test_loss: 0.007490851283073426
step 0: loss = 0.0070
test_loss: 0.007499351501464844
step 0: loss = 0.0070
test_loss: 0.007584636807441711
step 0: loss = 0.0068
test_loss: 0.007509540319442749
step 0: loss = 0.0072
test_loss: 0.007569836378097534
step 0: loss = 0.0072
test_loss: 0.007652132511138916
step 0: loss = 0.0070
test_loss: 0.007594484090805054
step 0: loss = 0.0069
test_loss: 0.0075802308320999144
step 0: loss = 0.0066
test_loss: 0.007369371056556702
step 0: loss = 0.0065
test_loss: 0.00750413715839386
step 0: loss = 0.0067
test_loss: 0.0076344400644302365
step 0: loss = 0.0073
test_loss: 0.007574498057365418
step 0: loss = 0.0068
test_loss: 0.007555763125419617
step 0: loss = 0.0068
test_l

step 0: loss = 0.0071
test_loss: 0.007561483979225158
step 0: loss = 0.0071
test_loss: 0.007507469058036804
step 0: loss = 0.0068
test_loss: 0.007628784775733948
step 0: loss = 0.0071
test_loss: 0.007499600648880005
step 0: loss = 0.0067
test_loss: 0.007657656073570251
step 0: loss = 0.0074
test_loss: 0.007545730471611023
step 0: loss = 0.0068
test_loss: 0.007458775639533997
step 0: loss = 0.0071
test_loss: 0.007617742419242859
step 0: loss = 0.0068
test_loss: 0.007491406202316284
step 0: loss = 0.0069
test_loss: 0.007425532937049866
step 0: loss = 0.0066
test_loss: 0.00749030590057373
step 0: loss = 0.0071
test_loss: 0.007629766464233399
step 0: loss = 0.0070
test_loss: 0.007405996918678283
step 0: loss = 0.0068
test_loss: 0.007513694763183594
step 0: loss = 0.0066
test_loss: 0.007456520795822144
step 0: loss = 0.0074
test_loss: 0.00745265007019043
step 0: loss = 0.0067
test_loss: 0.007489431500434875
step 0: loss = 0.0073
test_loss: 0.007431566119194031
step 0: loss = 0.0070
test_los

step 0: loss = 0.0072
test_loss: 0.007412884831428528
step 0: loss = 0.0066
test_loss: 0.007396809458732605
step 0: loss = 0.0067
test_loss: 0.00741973340511322
step 0: loss = 0.0069
test_loss: 0.007657722234725952
step 0: loss = 0.0074
test_loss: 0.007593307495117187
step 0: loss = 0.0072
test_loss: 0.00739227831363678
step 0: loss = 0.0068
test_loss: 0.0074299460649490355
step 0: loss = 0.0068
test_loss: 0.00747467815876007
step 0: loss = 0.0068
test_loss: 0.007377723455429077
step 0: loss = 0.0066
test_loss: 0.007351070046424866
step 0: loss = 0.0068
test_loss: 0.007520743608474731
step 0: loss = 0.0069
test_loss: 0.007329471111297607
step 0: loss = 0.0064
test_loss: 0.007448987364768982
step 0: loss = 0.0067
test_loss: 0.0074336987733840945
step 0: loss = 0.0068
test_loss: 0.007382505536079407
step 0: loss = 0.0066
test_loss: 0.0073910045623779294
step 0: loss = 0.0069
test_loss: 0.00732887864112854
step 0: loss = 0.0065
test_loss: 0.0074414688348770144
step 0: loss = 0.0070
test_l

step 0: loss = 0.0069
test_loss: 0.007295066118240356
step 0: loss = 0.0067
test_loss: 0.007381826043128968
step 0: loss = 0.0068
test_loss: 0.007337989807128906
step 0: loss = 0.0069
test_loss: 0.007338492274284363
step 0: loss = 0.0068
test_loss: 0.007499850988388061
step 0: loss = 0.0072
test_loss: 0.007526130080223083
step 0: loss = 0.0067
test_loss: 0.007295888662338257
step 0: loss = 0.0067
test_loss: 0.007514209151268005
step 0: loss = 0.0068
test_loss: 0.0074039483070373535
step 0: loss = 0.0064
test_loss: 0.007308778762817383
step 0: loss = 0.0067
test_loss: 0.007506663799285889
step 0: loss = 0.0070
test_loss: 0.007566205263137817
step 0: loss = 0.0072
test_loss: 0.007369909882545472
step 0: loss = 0.0068
test_loss: 0.00733225166797638
step 0: loss = 0.0067
test_loss: 0.007357751727104187
step 0: loss = 0.0068
test_loss: 0.007417133450508118
step 0: loss = 0.0068
test_loss: 0.007387086153030396
step 0: loss = 0.0069
test_loss: 0.00735380470752716
step 0: loss = 0.0068
test_lo

step 0: loss = 0.0072
test_loss: 0.007372123003005981
step 0: loss = 0.0075
test_loss: 0.007497678399085998
step 0: loss = 0.0068
test_loss: 0.007398601174354553
step 0: loss = 0.0065
test_loss: 0.007379490733146667
step 0: loss = 0.0066
test_loss: 0.007398941516876221
step 0: loss = 0.0070
test_loss: 0.007457020282745361
step 0: loss = 0.0067
test_loss: 0.0074291789531707765
step 0: loss = 0.0066
test_loss: 0.007375529408454895
step 0: loss = 0.0064
test_loss: 0.007378502488136291
step 0: loss = 0.0066
test_loss: 0.007461173534393311
step 0: loss = 0.0072
test_loss: 0.007337658405303955
step 0: loss = 0.0068
test_loss: 0.00729154646396637
step 0: loss = 0.0065
test_loss: 0.007310603260993958
step 0: loss = 0.0068
test_loss: 0.007356213331222534
step 0: loss = 0.0068
test_loss: 0.007459335327148438
step 0: loss = 0.0067
test_loss: 0.007353471517562866
step 0: loss = 0.0067
test_loss: 0.007327964901924134
step 0: loss = 0.0065
test_loss: 0.007285002470016479
step 0: loss = 0.0066
test_l

step 0: loss = 0.0066
test_loss: 0.007519039511680603
step 0: loss = 0.0070
test_loss: 0.007401856184005737
step 0: loss = 0.0067
test_loss: 0.007292234301567078
step 0: loss = 0.0067
test_loss: 0.007226125001907348
step 0: loss = 0.0063
test_loss: 0.007324571013450623
step 0: loss = 0.0062
test_loss: 0.007316399216651916
step 0: loss = 0.0065
test_loss: 0.007335711121559143
step 0: loss = 0.0067
test_loss: 0.007244875431060791
step 0: loss = 0.0065
test_loss: 0.007374393939971924
step 0: loss = 0.0070
test_loss: 0.007357509136199951
step 0: loss = 0.0070
test_loss: 0.007259456515312195
step 0: loss = 0.0066
test_loss: 0.007337398529052734
step 0: loss = 0.0070
test_loss: 0.007373843789100647
step 0: loss = 0.0068
test_loss: 0.0074986833333969116
step 0: loss = 0.0066
test_loss: 0.007316855788230896
step 0: loss = 0.0069
test_loss: 0.007343579530715943
step 0: loss = 0.0067
test_loss: 0.007284387350082398
step 0: loss = 0.0065
test_loss: 0.007213298678398133
step 0: loss = 0.0065
test_

test_loss: 0.007580656409263611
step 0: loss = 0.0076
test_loss: 0.007261404395103454
step 0: loss = 0.0066
test_loss: 0.0071904569864273075
step 0: loss = 0.0066
test_loss: 0.007337878942489624
step 0: loss = 0.0069
test_loss: 0.00734557569026947
step 0: loss = 0.0067
test_loss: 0.007252202033996582
step 0: loss = 0.0065
test_loss: 0.007280610203742981
step 0: loss = 0.0069
test_loss: 0.007274398803710938
step 0: loss = 0.0066
test_loss: 0.007446699738502502
step 0: loss = 0.0066
test_loss: 0.007381008267402649
step 0: loss = 0.0067
test_loss: 0.007516281008720398
step 0: loss = 0.0067
test_loss: 0.007177793979644775
step 0: loss = 0.0065
test_loss: 0.0072316288948059086
step 0: loss = 0.0066
test_loss: 0.007288329005241394
step 0: loss = 0.0069
test_loss: 0.00726102888584137
step 0: loss = 0.0068
test_loss: 0.0071990108489990235
step 0: loss = 0.0064
test_loss: 0.007385075688362121
step 0: loss = 0.0068
test_loss: 0.007393034100532532
step 0: loss = 0.0067
test_loss: 0.00719145178794

step 0: loss = 0.0064
test_loss: 0.007224013805389404
step 0: loss = 0.0063
test_loss: 0.007328650951385498
step 0: loss = 0.0065
test_loss: 0.007179969549179077
step 0: loss = 0.0063
test_loss: 0.0073207104206085205
step 0: loss = 0.0065
test_loss: 0.0073129040002822875
step 0: loss = 0.0065
test_loss: 0.007274292111396789
step 0: loss = 0.0063
test_loss: 0.007412489056587219
step 0: loss = 0.0069
test_loss: 0.007177384495735168
step 0: loss = 0.0064
test_loss: 0.00728963553905487
step 0: loss = 0.0068
test_loss: 0.007232493162155152
step 0: loss = 0.0064
test_loss: 0.007237452268600464
step 0: loss = 0.0067
test_loss: 0.007309420108795166
step 0: loss = 0.0072
test_loss: 0.007355222105979919
step 0: loss = 0.0064
test_loss: 0.007366428375244141
step 0: loss = 0.0065
test_loss: 0.0074811190366745
step 0: loss = 0.0071
test_loss: 0.007397143244743347
step 0: loss = 0.0071
test_loss: 0.00730400562286377
step 0: loss = 0.0067
test_loss: 0.007225134372711182
step 0: loss = 0.0064
test_los

test_loss: 0.007260974049568176
step 0: loss = 0.0065
test_loss: 0.007196415066719055
step 0: loss = 0.0067
test_loss: 0.007259511351585388
step 0: loss = 0.0068
test_loss: 0.007204980850219727
step 0: loss = 0.0065
test_loss: 0.007175691127777099
step 0: loss = 0.0063
test_loss: 0.007355198860168457
step 0: loss = 0.0066
test_loss: 0.007260179519653321
step 0: loss = 0.0067
test_loss: 0.007289548516273499
step 0: loss = 0.0066
test_loss: 0.007284354567527771
step 0: loss = 0.0069
test_loss: 0.007151367664337158
step 0: loss = 0.0066
test_loss: 0.007101219892501831
step 0: loss = 0.0070
test_loss: 0.007191224694252014
step 0: loss = 0.0066
test_loss: 0.007347407341003418
step 0: loss = 0.0071
test_loss: 0.0071829754114151
step 0: loss = 0.0064
test_loss: 0.007315191626548767
step 0: loss = 0.0066
test_loss: 0.00720499038696289
step 0: loss = 0.0070
test_loss: 0.007183167338371277
step 0: loss = 0.0064
test_loss: 0.007284337282180786
step 0: loss = 0.0069
test_loss: 0.007221901416778564

test_loss: 0.007284278869628906
step 0: loss = 0.0067
test_loss: 0.00715767502784729
step 0: loss = 0.0067
test_loss: 0.007222236394882202
step 0: loss = 0.0064
test_loss: 0.007181307673454285
step 0: loss = 0.0067
test_loss: 0.007277373671531678
step 0: loss = 0.0070
test_loss: 0.007318930625915527
step 0: loss = 0.0069
test_loss: 0.007228391170501709
step 0: loss = 0.0072
test_loss: 0.007105883359909058
step 0: loss = 0.0065
test_loss: 0.007334537506103515
step 0: loss = 0.0068
test_loss: 0.007153111696243286
step 0: loss = 0.0066
test_loss: 0.0071962857246398925
step 0: loss = 0.0064
test_loss: 0.007235568165779114
step 0: loss = 0.0065
test_loss: 0.007215235233306885
step 0: loss = 0.0069
test_loss: 0.007317765355110168
step 0: loss = 0.0067
test_loss: 0.007200872302055359
step 0: loss = 0.0066
test_loss: 0.007204765677452088
step 0: loss = 0.0068
test_loss: 0.007292532920837402
step 0: loss = 0.0068
test_loss: 0.007286049723625183
step 0: loss = 0.0065
test_loss: 0.007265218496322

step 0: loss = 0.0065
test_loss: 0.007084613442420959
step 0: loss = 0.0066
test_loss: 0.0072765350341796875
step 0: loss = 0.0065
test_loss: 0.00717840850353241
step 0: loss = 0.0066
test_loss: 0.0072084182500839235
step 0: loss = 0.0065
test_loss: 0.007270919680595398
step 0: loss = 0.0069
test_loss: 0.007113423347473145
step 0: loss = 0.0064
test_loss: 0.007109050750732422
step 0: loss = 0.0066
test_loss: 0.00721243441104889
step 0: loss = 0.0065
test_loss: 0.0071212238073349
step 0: loss = 0.0065
test_loss: 0.0071258538961410526
step 0: loss = 0.0065
test_loss: 0.007152566909790039
step 0: loss = 0.0067
test_loss: 0.007189176082611084
step 0: loss = 0.0067
test_loss: 0.007176584601402283
step 0: loss = 0.0068
test_loss: 0.007162790298461914
step 0: loss = 0.0065
test_loss: 0.007124680280685425
step 0: loss = 0.0067
test_loss: 0.007230540513992309
step 0: loss = 0.0064
test_loss: 0.007177801728248596
step 0: loss = 0.0065
test_loss: 0.007339293360710144
step 0: loss = 0.0069
test_lo

step 0: loss = 0.0065
test_loss: 0.007088668346405029
step 0: loss = 0.0061
test_loss: 0.007117195129394531
step 0: loss = 0.0064
test_loss: 0.007181736230850219
step 0: loss = 0.0068
test_loss: 0.007123434543609619
step 0: loss = 0.0065
test_loss: 0.007120313644409179
step 0: loss = 0.0063
test_loss: 0.007272140979766846
step 0: loss = 0.0067
test_loss: 0.007251309752464294
step 0: loss = 0.0067
test_loss: 0.007267163395881653
step 0: loss = 0.0066
test_loss: 0.007274624109268189
step 0: loss = 0.0070
test_loss: 0.0073009192943573
step 0: loss = 0.0065
test_loss: 0.007128223180770874
step 0: loss = 0.0067
test_loss: 0.007181137204170227
step 0: loss = 0.0066
test_loss: 0.007286390662193299
step 0: loss = 0.0063
test_loss: 0.007243632078170776
step 0: loss = 0.0066
test_loss: 0.0073079907894134525
step 0: loss = 0.0070
test_loss: 0.007245160937309265
step 0: loss = 0.0065
test_loss: 0.007255789041519165
step 0: loss = 0.0069
test_loss: 0.007043866515159607
step 0: loss = 0.0065
test_lo

step 0: loss = 0.0060
test_loss: 0.007117990851402283
step 0: loss = 0.0066
test_loss: 0.007118674516677856
step 0: loss = 0.0066
test_loss: 0.007232027649879455
step 0: loss = 0.0065
test_loss: 0.007204571962356568
step 0: loss = 0.0065
test_loss: 0.0071440738439559935
step 0: loss = 0.0064
test_loss: 0.007100012302398682
step 0: loss = 0.0065
test_loss: 0.007136197090148926
step 0: loss = 0.0064
test_loss: 0.007200296521186828
step 0: loss = 0.0063
test_loss: 0.007073160409927368
step 0: loss = 0.0065
test_loss: 0.007135769724845887
step 0: loss = 0.0065
test_loss: 0.007217938899993896
step 0: loss = 0.0064
test_loss: 0.007226162552833557
step 0: loss = 0.0067
test_loss: 0.007068595886230469
step 0: loss = 0.0061
test_loss: 0.00703535258769989
step 0: loss = 0.0066
test_loss: 0.007222074270248413
step 0: loss = 0.0066
test_loss: 0.007149312496185303
step 0: loss = 0.0066
test_loss: 0.007119871973991394
step 0: loss = 0.0067
test_loss: 0.0071729934215545655
step 0: loss = 0.0068
test_

test_loss: 0.00706695020198822
step 0: loss = 0.0065
test_loss: 0.007126185297966004
step 0: loss = 0.0069
test_loss: 0.007191649675369263
step 0: loss = 0.0067
test_loss: 0.007214884161949158
step 0: loss = 0.0067
test_loss: 0.007128168344497681
step 0: loss = 0.0064
test_loss: 0.007218003869056701
step 0: loss = 0.0064
test_loss: 0.007209978103637695
step 0: loss = 0.0061
test_loss: 0.007259039282798767
step 0: loss = 0.0063
test_loss: 0.007164142727851867
step 0: loss = 0.0064
test_loss: 0.007094274759292602
step 0: loss = 0.0065
test_loss: 0.007230829000473022
step 0: loss = 0.0063
test_loss: 0.007186207175254822
step 0: loss = 0.0066
test_loss: 0.00709056556224823
step 0: loss = 0.0064
test_loss: 0.007023504376411438
step 0: loss = 0.0058
test_loss: 0.007137428522109985
step 0: loss = 0.0063
test_loss: 0.007098371386528015
step 0: loss = 0.0062
test_loss: 0.007189568281173706
step 0: loss = 0.0065
test_loss: 0.007085736393928527
step 0: loss = 0.0065
test_loss: 0.00709701061248779

test_loss: 0.007148106694221497
step 0: loss = 0.0068
test_loss: 0.0071296805143356325
step 0: loss = 0.0065
test_loss: 0.006973568201065063
step 0: loss = 0.0064
test_loss: 0.007130130529403687
step 0: loss = 0.0066
test_loss: 0.007172419428825378
step 0: loss = 0.0064
test_loss: 0.007152055501937866
step 0: loss = 0.0065
test_loss: 0.007161804437637329
step 0: loss = 0.0066
test_loss: 0.007267546653747558
step 0: loss = 0.0067
test_loss: 0.007155967950820923
step 0: loss = 0.0062
test_loss: 0.007232142686843872
step 0: loss = 0.0065
test_loss: 0.007115540504455567
step 0: loss = 0.0062
test_loss: 0.007099190354347229
step 0: loss = 0.0063
test_loss: 0.007094418406486511
step 0: loss = 0.0063
test_loss: 0.007087306380271912
step 0: loss = 0.0063
test_loss: 0.0071023541688919065
step 0: loss = 0.0065
test_loss: 0.007169934511184692
step 0: loss = 0.0066
test_loss: 0.007039833664894104
step 0: loss = 0.0066
test_loss: 0.007007317543029785
step 0: loss = 0.0062
test_loss: 0.0071324437856

step 0: loss = 0.0063
test_loss: 0.007118214964866638
step 0: loss = 0.0066
test_loss: 0.007167964577674866
step 0: loss = 0.0068
test_loss: 0.0071815353631973266
step 0: loss = 0.0064
test_loss: 0.006991901397705078
step 0: loss = 0.0064
test_loss: 0.007189565300941467
step 0: loss = 0.0070
test_loss: 0.007079840898513794
step 0: loss = 0.0063
test_loss: 0.007191955447196961
step 0: loss = 0.0067
test_loss: 0.007123261094093323
step 0: loss = 0.0065
test_loss: 0.007105764746665954
step 0: loss = 0.0063
test_loss: 0.007084302306175232
step 0: loss = 0.0065
test_loss: 0.007075803875923156
step 0: loss = 0.0064
test_loss: 0.007054941058158875
step 0: loss = 0.0065
test_loss: 0.0070797944068908695
step 0: loss = 0.0066
test_loss: 0.007207038402557373
step 0: loss = 0.0071
test_loss: 0.007148467898368835
step 0: loss = 0.0065
test_loss: 0.007044929265975952
step 0: loss = 0.0068
test_loss: 0.007238890528678894
step 0: loss = 0.0064
test_loss: 0.007068529129028321
step 0: loss = 0.0063
test

step 0: loss = 0.0064
test_loss: 0.007135406732559204
step 0: loss = 0.0063
test_loss: 0.007026495337486267
step 0: loss = 0.0063
test_loss: 0.007119653224945069
step 0: loss = 0.0064
test_loss: 0.007051026821136475
step 0: loss = 0.0064
test_loss: 0.007069064974784851
step 0: loss = 0.0065
test_loss: 0.007107229828834534
step 0: loss = 0.0065
test_loss: 0.007067221403121948
step 0: loss = 0.0064
test_loss: 0.007129796147346497
step 0: loss = 0.0064
test_loss: 0.006998395323753357
step 0: loss = 0.0066
test_loss: 0.007142983078956604
step 0: loss = 0.0067
test_loss: 0.007081297636032105
step 0: loss = 0.0070
test_loss: 0.00708304762840271
step 0: loss = 0.0064
test_loss: 0.007171440124511719
step 0: loss = 0.0064
test_loss: 0.0071115452051162716
step 0: loss = 0.0065
test_loss: 0.007251394391059876
step 0: loss = 0.0069
test_loss: 0.007111252546310425
step 0: loss = 0.0061
test_loss: 0.007006093859672546
step 0: loss = 0.0062
test_loss: 0.007180189490318298
step 0: loss = 0.0065
test_l

step 0: loss = 0.0068
test_loss: 0.007034664750099182
step 0: loss = 0.0063
test_loss: 0.007189445495605469
step 0: loss = 0.0068
test_loss: 0.0070099538564682
step 0: loss = 0.0062
test_loss: 0.007156102657318115
step 0: loss = 0.0065
test_loss: 0.007137086391448975
step 0: loss = 0.0067
test_loss: 0.007045384049415588
step 0: loss = 0.0063
test_loss: 0.007019435167312622
step 0: loss = 0.0064
test_loss: 0.007122959494590759
step 0: loss = 0.0067
test_loss: 0.00704383134841919
step 0: loss = 0.0065
test_loss: 0.006982791423797608
step 0: loss = 0.0065
test_loss: 0.0070962607860565186
step 0: loss = 0.0066
test_loss: 0.00710097849369049
step 0: loss = 0.0066
test_loss: 0.007111571431159973
step 0: loss = 0.0065
test_loss: 0.007111772894859314
step 0: loss = 0.0063
test_loss: 0.007121503949165344
step 0: loss = 0.0064
test_loss: 0.007083995938301086
step 0: loss = 0.0063
test_loss: 0.007133997082710266
step 0: loss = 0.0064
test_loss: 0.007151498198509216
step 0: loss = 0.0063
test_loss

step 0: loss = 0.0065
test_loss: 0.006987430453300476
step 0: loss = 0.0067
test_loss: 0.007105467319488526
step 0: loss = 0.0066
test_loss: 0.007133679389953613
step 0: loss = 0.0066
test_loss: 0.0071617597341537475
step 0: loss = 0.0066
test_loss: 0.007031846046447754
step 0: loss = 0.0062
test_loss: 0.007006741762161255
step 0: loss = 0.0065
test_loss: 0.007147594690322876
step 0: loss = 0.0064
test_loss: 0.0071385037899017335
step 0: loss = 0.0062
test_loss: 0.0070074242353439335
step 0: loss = 0.0063
test_loss: 0.007119261026382447
step 0: loss = 0.0065
test_loss: 0.007031545639038086
step 0: loss = 0.0063
test_loss: 0.007166115045547486
step 0: loss = 0.0065
test_loss: 0.007028471827507019
step 0: loss = 0.0064
test_loss: 0.007062638401985169
step 0: loss = 0.0063
test_loss: 0.007099561095237732
step 0: loss = 0.0066
test_loss: 0.007068392038345337
step 0: loss = 0.0065
test_loss: 0.007225245237350464
step 0: loss = 0.0067
test_loss: 0.007063790559768677
step 0: loss = 0.0065
tes

step 0: loss = 0.0067
test_loss: 0.006985004544258117
step 0: loss = 0.0063
test_loss: 0.0069847780466079715
step 0: loss = 0.0060
test_loss: 0.007022766470909119
step 0: loss = 0.0061
test_loss: 0.007020527720451355
step 0: loss = 0.0064
test_loss: 0.007048904895782471
step 0: loss = 0.0064
test_loss: 0.007062922716140747
step 0: loss = 0.0064
test_loss: 0.007141086459159851
step 0: loss = 0.0068
test_loss: 0.007056018114089966
step 0: loss = 0.0060
test_loss: 0.007056000828742981
step 0: loss = 0.0064
test_loss: 0.007026289701461792
step 0: loss = 0.0066
test_loss: 0.006980472207069397
step 0: loss = 0.0065
test_loss: 0.007074512243270874
step 0: loss = 0.0065
test_loss: 0.007065949440002442
step 0: loss = 0.0065
test_loss: 0.007070410847663879
step 0: loss = 0.0063
test_loss: 0.0070761370658874514
step 0: loss = 0.0066
test_loss: 0.007127672433853149
step 0: loss = 0.0064
test_loss: 0.007128627300262451
step 0: loss = 0.0065
test_loss: 0.007027390003204345
step 0: loss = 0.0067
test

test_loss: 0.006966246962547302
step 0: loss = 0.0062
test_loss: 0.007041746377944946
step 0: loss = 0.0063
test_loss: 0.00695745050907135
step 0: loss = 0.0063
test_loss: 0.006996049284934997
step 0: loss = 0.0061
test_loss: 0.006998875141143799
step 0: loss = 0.0061
test_loss: 0.006998748183250427
step 0: loss = 0.0065
test_loss: 0.007036427259445191
step 0: loss = 0.0063
test_loss: 0.007074309587478638
step 0: loss = 0.0065
test_loss: 0.007033665776252747
step 0: loss = 0.0066
test_loss: 0.006999208331108093
step 0: loss = 0.0064
test_loss: 0.006966820955276489
step 0: loss = 0.0064
test_loss: 0.00700967013835907
step 0: loss = 0.0066
test_loss: 0.007016467452049255
step 0: loss = 0.0064
test_loss: 0.007168960571289062
step 0: loss = 0.0063
test_loss: 0.00699995756149292
step 0: loss = 0.0063
test_loss: 0.007131304740905762
step 0: loss = 0.0068
test_loss: 0.006984120011329651
step 0: loss = 0.0066
test_loss: 0.007054160237312317
step 0: loss = 0.0064
test_loss: 0.007028155922889709

step 0: loss = 0.0067
test_loss: 0.0070181721448898315
step 0: loss = 0.0068
test_loss: 0.007042770385742188
step 0: loss = 0.0065
test_loss: 0.0069933515787124634
step 0: loss = 0.0065
test_loss: 0.00703650951385498
step 0: loss = 0.0065
test_loss: 0.007051728963851929
step 0: loss = 0.0061
test_loss: 0.00699335515499115
step 0: loss = 0.0064
test_loss: 0.007127501368522644
step 0: loss = 0.0066
test_loss: 0.006936307549476623
step 0: loss = 0.0063
test_loss: 0.007038989067077637
step 0: loss = 0.0064
test_loss: 0.007152613401412964
step 0: loss = 0.0067
test_loss: 0.00696435272693634
step 0: loss = 0.0063
test_loss: 0.006983663439750671
step 0: loss = 0.0062
test_loss: 0.007034767270088196
step 0: loss = 0.0069
test_loss: 0.006943516731262207
step 0: loss = 0.0065
test_loss: 0.006917015910148621
step 0: loss = 0.0061
test_loss: 0.007020125985145569
step 0: loss = 0.0067
test_loss: 0.007040620446205139
step 0: loss = 0.0062
test_loss: 0.007029950022697449
step 0: loss = 0.0063
test_lo

step 0: loss = 0.0066
test_loss: 0.007069270014762878
step 0: loss = 0.0062
test_loss: 0.007030076384544373
step 0: loss = 0.0061
test_loss: 0.007072692513465881
step 0: loss = 0.0062
test_loss: 0.006982313990592956
step 0: loss = 0.0067
test_loss: 0.0070342087745666505
step 0: loss = 0.0063
test_loss: 0.00696378767490387
step 0: loss = 0.0062
test_loss: 0.007054601311683655
step 0: loss = 0.0063
test_loss: 0.007034944891929626
step 0: loss = 0.0065
test_loss: 0.007093643546104431
step 0: loss = 0.0063
test_loss: 0.006941607594490052
step 0: loss = 0.0065
test_loss: 0.006901085972785949
step 0: loss = 0.0066
test_loss: 0.007105556726455689
step 0: loss = 0.0063
test_loss: 0.006955658197402954
step 0: loss = 0.0064
test_loss: 0.00704873263835907
step 0: loss = 0.0062
test_loss: 0.007033311128616333
step 0: loss = 0.0064
test_loss: 0.007001749277114868
step 0: loss = 0.0064
test_loss: 0.007088199257850647
step 0: loss = 0.0066
test_loss: 0.007029955387115479
step 0: loss = 0.0064
test_lo

step 0: loss = 0.0059
test_loss: 0.007018402814865112
step 0: loss = 0.0063
test_loss: 0.006974183320999146
step 0: loss = 0.0064
test_loss: 0.007028198838233948
step 0: loss = 0.0063
test_loss: 0.007010753154754639
step 0: loss = 0.0064
test_loss: 0.006989938616752625
step 0: loss = 0.0064
test_loss: 0.006944034099578857
step 0: loss = 0.0063
test_loss: 0.0069385749101638795
step 0: loss = 0.0063
test_loss: 0.006959058046340942
step 0: loss = 0.0062
test_loss: 0.007022362947463989
step 0: loss = 0.0063
test_loss: 0.006914392709732055
step 0: loss = 0.0060
test_loss: 0.006998027563095093
step 0: loss = 0.0061
test_loss: 0.007120726108551026
step 0: loss = 0.0065
test_loss: 0.006988462209701538
step 0: loss = 0.0063
test_loss: 0.0069773197174072266
step 0: loss = 0.0063
test_loss: 0.006948070526123047
step 0: loss = 0.0065
test_loss: 0.0070305800437927245
step 0: loss = 0.0064
test_loss: 0.007030363082885742
step 0: loss = 0.0066
test_loss: 0.007033449411392212
step 0: loss = 0.0066
tes

### Some debugging tests:

In [None]:
@tf.function
def volume_form(x, Omega_Omegabar, mass, restriction):

    kahler_metric = complex_hessian(tf.math.real(model(x)), x)
    volume_form = tf.linalg.det(tf.matmul(restriction, tf.matmul(kahler_metric, restriction, adjoint_b=True)))
    weights = mass / tf.reduce_sum(mass)
    factor = tf.reduce_sum(weights * volume_form / Omega_Omegabar)
    #factor = tf.constant(4.380538, dtype=tf.complex64)
    return  volume_form/factor
    #return factor
for step, (points, Omega_Omegabar, mass, restriction) in enumerate(dataset):
    omega = volume_form(points, Omega_Omegabar, mass, restriction)
    
    weights = mass / tf.reduce_sum(mass)
    print('omega', omega)
    print('OO',Omega_Omegabar)
    print(tf.cast(tf.abs(Omega_Omegabar -  omega), dtype=tf.complex64) / Omega_Omegabar)
   # print(mass/tf.reduce_sum(mass))

In [16]:
for weight in model.trainable_weights:
    print(weight)

<tf.Variable 'Variable:0' shape=(5, 15) dtype=complex64, numpy=
array([[ 2.2768168e+00+1.6352530e+00j,  5.8658237e+00-4.2718673e-01j,
         5.3873414e-01-1.7244281e-01j, -3.9534414e-01+7.2521999e-02j,
        -1.8923687e-02+1.1802919e-02j, -3.8347977e-01-2.2627316e-01j,
        -7.6662725e-01+8.0790210e-01j,  2.6793966e-01+6.7697294e-02j,
        -4.4106621e-01-1.5242220e-01j, -1.6198709e+00-5.7326740e-01j,
         2.8095651e+00+8.1460869e-01j, -1.1343197e+00-2.5883400e+00j,
        -4.6649609e+00+1.1073344e+00j,  1.7060937e-01-2.2714563e-02j,
         8.5827917e-01-7.9530555e-01j],
       [ 1.9180773e+00+4.4432992e-01j, -1.7982172e+00+1.0105304e+00j,
        -6.0589733e+00-8.5629034e-01j,  2.9838247e+00-3.7910095e-01j,
         2.8432586e+00+9.5829725e-01j, -4.8352519e-01-4.6606249e-01j,
        -1.7036586e+00-9.6777737e-01j, -3.2158126e-03+1.6477600e-03j,
         1.7850419e+00+6.1775154e-01j,  4.3784446e-01-8.8729739e-02j,
         2.8238079e-01+1.9228293e-01j, -2.9208440e-01+1.

In [21]:
class KahlerPotential(tf.keras.Model):

    def __init__(self):
        super(KahlerPotential, self).__init__()
        self.biholomorphic = Biholomorphic()
        self.layer1 = tf.keras.layers.Dense(100, activation=tf.square, use_bias=False)


        #self.layer_4 = ComplexDense(50, 10, activation=tf.square)
        #self.layer_3 = ComplexDense(10, 15, activation=tf.square)
        #self.g = ComplexG(70)
        
    def call(self, inputs):
        x = self.biholomorphic(inputs)
        x = self.layer1(x)
        #x = self.layer_2(x)

        #x = self.layer_4(x)
        #x = self.g(x)
        #x = tf.linalg.diag_part(tf.matmul(x, x, adjoint_b=True))
        #x = tf.math.log(x)
        x = tf.reduce_sum(x, 1)
        x = tf.math.log(x)
        return x

In [6]:
model = KahlerPotential()

In [13]:
for step, (points, Omega_Omegabar, mass, restriction) in enumerate(train_set):
    print(tf.shape(model(points)))
    if step == 0:
        a = model(points)

tf.Tensor([1000], shape=(1,), dtype=int32)
tf.Tensor([1000], shape=(1,), dtype=int32)
tf.Tensor([1000], shape=(1,), dtype=int32)
tf.Tensor([1000], shape=(1,), dtype=int32)
tf.Tensor([1000], shape=(1,), dtype=int32)
tf.Tensor([1000], shape=(1,), dtype=int32)
tf.Tensor([1000], shape=(1,), dtype=int32)
tf.Tensor([1000], shape=(1,), dtype=int32)
tf.Tensor([1000], shape=(1,), dtype=int32)
tf.Tensor([1000], shape=(1,), dtype=int32)
tf.Tensor([1000], shape=(1,), dtype=int32)
tf.Tensor([1000], shape=(1,), dtype=int32)
tf.Tensor([1000], shape=(1,), dtype=int32)
tf.Tensor([1000], shape=(1,), dtype=int32)
tf.Tensor([1000], shape=(1,), dtype=int32)
tf.Tensor([1000], shape=(1,), dtype=int32)
tf.Tensor([1000], shape=(1,), dtype=int32)
tf.Tensor([1000], shape=(1,), dtype=int32)
tf.Tensor([1000], shape=(1,), dtype=int32)
tf.Tensor([1000], shape=(1,), dtype=int32)
tf.Tensor([1000], shape=(1,), dtype=int32)
tf.Tensor([1000], shape=(1,), dtype=int32)
tf.Tensor([1000], shape=(1,), dtype=int32)
tf.Tensor([

In [9]:
a = tf.transpose(a)
a = tf.reduce_sum(tf.abs(a), 1)

In [8]:
# TODO: Change Batchsize & Network structure

<tf.Tensor: shape=(1000, 1), dtype=float32, numpy=
array([[5.38316548e-01],
       [8.83447230e-02],
       [5.26222169e-01],
       [7.88766980e-01],
       [1.48804903e-01],
       [7.92228580e-01],
       [2.03350872e-01],
       [6.20962977e-01],
       [1.69604346e-01],
       [8.83976184e-03],
       [6.39118850e-02],
       [1.27156880e-02],
       [6.97958589e-01],
       [3.64552498e-01],
       [2.37538330e-02],
       [1.51279286e-01],
       [1.08250463e+00],
       [3.38911235e-01],
       [2.04938009e-01],
       [4.25134525e-02],
       [8.43953907e-01],
       [1.02830306e-03],
       [7.30025589e-01],
       [4.43385094e-01],
       [3.14193010e-01],
       [1.46317527e-01],
       [1.05443120e+00],
       [9.39430669e-02],
       [2.25267470e-01],
       [1.41357586e-01],
       [2.44012162e-01],
       [3.09310365e-03],
       [3.00194952e-03],
       [1.89879518e-02],
       [1.02200031e-01],
       [3.23994249e-01],
       [1.09909177e-01],
       [2.04617605e-01],

In [26]:
a**2

<tf.Tensor: shape=(2,), dtype=int32, numpy=array([1, 4], dtype=int32)>

In [26]:
a_real = tf.math.real(a)
a_imag = tf.math.imag(a)
a_new = tf.concat([a_real, a_imag], axis=1)

In [35]:
boolean_mask = tf.cast(a, dtype=tf.bool)              
no_zeros = tf.ragged.boolean_mask(a, boolean_mask)

In [9]:
tf.shape(a)

<tf.Tensor: shape=(2,), dtype=int32, numpy=array([1000,   50], dtype=int32)>

In [30]:
a_real[0]

<tf.Tensor: shape=(25,), dtype=float32, numpy=
array([ 0.9038784 ,  0.84662837, -0.09585851, -0.663743  , -0.43639532,
        0.8466283 ,  0.95462674, -0.0439431 , -0.8095745 , -0.05158594,
       -0.09585851, -0.0439431 ,  0.02316958,  0.01710205,  0.14759119,
       -0.663743  , -0.8095745 ,  0.01710205,  0.7057893 , -0.09472027,
       -0.43639532, -0.05158594,  0.14759119, -0.09472027,  1.        ],
      dtype=float32)>

In [25]:
a = tf.constant([1, 2])