In [1]:
from hypersurface_tf import *
from generate_h import *
from complexNN import *
import tensorflow as tf

### Prepare the dataset:

In [2]:
# Symbolic coordinates and the defining polynomial: the Fermat quintic
# sum_i z_i**5 deformed by the term 0.5 * z0*z1*z2*z3*z4.
z0, z1, z2, z3, z4 = sp.symbols('z0, z1, z2, z3, z4')
Z = [z0, z1, z2, z3, z4]
f = sum(z**5 for z in Z) + 0.5*z0*z1*z2*z3*z4
# The third argument is presumably the number of sampled points
# (100k for training, 10k for testing) — TODO confirm in hypersurface_tf.
HS, HS_test = Hypersurface(Z, f, 100000), Hypersurface(Z, f, 10000)

In [3]:
# Convert the sampled train/test hypersurfaces into datasets
# (evaluated left to right: train first, then test).
train_set, test_set = generate_dataset(HS), generate_dataset(HS_test)







In [4]:
# Shuffle and group into batches of 1000 examples. The shuffle buffers are
# presumably at least as large as each dataset, i.e. a full shuffle.
# NOTE: this cell rebinds train_set/test_set to their batched versions, so
# running it twice batches the data twice — re-run the cell above first.
train_set, test_set = (
    train_set.shuffle(buffer_size=500000).batch(batch_size=1000),
    test_set.shuffle(buffer_size=50000).batch(batch_size=1000),
)

### Build the model:

In [5]:
class KahlerPotential(tf.keras.Model):
    """Neural Kähler potential: two ComplexDense layers, ComplexG, then log of a squared norm.

    For each input row, the network maps the inputs through the layers to a
    vector y (output of ComplexG) and returns log(sum_j y_j * conj(y_j)).
    The result is a tensor of y's dtype whose imaginary part is zero for
    complex y.

    Args:
        n_coords: width of the input vectors (default 5, matching the five
            coordinates z0..z4 used in this notebook).
        n_hidden: width of both hidden ComplexDense layers and of ComplexG
            (default 60).
    """

    def __init__(self, n_coords=5, n_hidden=60):
        super(KahlerPotential, self).__init__()
        # tf.square activations — presumably to keep the potential polynomial
        # in the inputs; TODO confirm intent.
        self.layer_1 = ComplexDense(n_coords, n_hidden, activation=tf.square)
        self.layer_2 = ComplexDense(n_hidden, n_hidden, activation=tf.square)
        self.g = ComplexG(n_hidden)

    def call(self, inputs):
        x = self.layer_1(inputs)
        x = self.layer_2(x)
        x = self.g(x)
        # Row-wise squared norm sum_j x_j * conj(x_j). This equals
        # diag(x @ x^H) but avoids materialising the (batch x batch) product,
        # which costs O(batch^2) memory just to read its diagonal.
        x = tf.reduce_sum(x * tf.math.conj(x), axis=-1)
        return tf.math.log(x)

In [6]:
# Instantiate the Kähler-potential network. `volume_form` and the training
# loop read this module-level `model` directly.
model = KahlerPotential()

In [7]:
@tf.function
def volume_form(x, Omega_Omegabar, mass, restriction):
    """Normalized volume form of the learned metric at each point in the batch.

    Takes the complex Hessian of Re(model(x)) with respect to x and forms
    restriction @ metric @ restriction^H. It then returns that matrix's
    determinant divided by a normalization constant. The constant is the
    mass-weighted mean of det / Omega_Omegabar, so the returned values have a
    mass-weighted mean ratio of 1 to Omega_Omegabar over the batch.

    Reads the module-level `model`.
    """
    metric = complex_hessian(tf.math.real(model(x)), x)
    # Sandwich the Hessian between `restriction` and its conjugate transpose.
    pulled_back = tf.matmul(restriction, tf.matmul(metric, restriction, adjoint_b=True))
    det = tf.linalg.det(pulled_back)

    mass_weights = mass / tf.reduce_sum(mass)
    normalization = tf.reduce_sum(mass_weights * det / Omega_Omegabar)
    return det / normalization

In [None]:
# Plain gradient-descent training with early stopping on the test loss.
learning_rate = 10
epochs = 20000
# Stop after this many consecutive epochs without a new best test loss.
# The per-epoch test loss is noisy, so stopping on the first increase
# would end training far too early.
patience = 10

best_test_loss = float('inf')
epochs_without_improvement = 0

for epoch in range(epochs):
    for step, (points, Omega_Omegabar, mass, restriction) in enumerate(train_set):
        with tf.GradientTape() as tape:
            omega = volume_form(points, Omega_Omegabar, mass, restriction)
            loss = weighted_MAPE(Omega_Omegabar, omega, mass)
        # Take the gradient after leaving the tape context so the gradient
        # computation itself is not recorded.
        grads = tape.gradient(loss, model.trainable_weights)

        # The update arithmetic runs on the CPU (kept from the original code —
        # presumably a workaround for complex64 update kernels on GPU; TODO
        # confirm). The result is then written back into the existing model
        # variable instead of allocating a temporary tf.Variable every step.
        with tf.device('/cpu:0'):
            for weight, grad in zip(model.trainable_weights, grads):
                weight.assign(weight - learning_rate * grad)

        if step % 50 == 0:
            # The loss may be complex-valued. Format its real part explicitly;
            # float() on a complex tensor drops the imaginary part with a
            # ComplexWarning.
            print("step %d: loss = %.4f" % (step, tf.math.real(loss)))

    # Mean of the per-batch test losses. Batches are counted explicitly rather
    # than reusing the loop variable, which would be stale if test_set were empty.
    test_loss_sum = 0
    n_test_batches = 0
    for points, Omega_Omegabar, mass, restriction in test_set:
        omega = volume_form(points, Omega_Omegabar, mass, restriction)
        test_loss_sum += weighted_MAPE(Omega_Omegabar, omega, mass)
        n_test_batches += 1
    if n_test_batches == 0:
        raise ValueError("test_set is empty; cannot evaluate test loss")

    test_loss = tf.math.real(test_loss_sum).numpy() / n_test_batches
    print("test_loss:", test_loss)

    # Early stopping. The previous version reset the reference loss to 100
    # inside this loop, so it could never trigger.
    if test_loss < best_test_loss:
        best_test_loss = test_loss
        epochs_without_improvement = 0
    else:
        epochs_without_improvement += 1
        if epochs_without_improvement >= patience:
            print("early stopping at epoch %d: no improvement for %d epochs (best test_loss: %f)"
                  % (epoch, patience, best_test_loss))
            break

  return float(self._numpy())


step 0: loss = 0.7650
step 50: loss = 0.5645
step 100: loss = 0.4075
step 150: loss = 0.3327
step 200: loss = 0.3220
step 250: loss = 0.2967
step 300: loss = 0.2833
step 350: loss = 0.2539
step 400: loss = 0.2477
step 450: loss = 0.2407
test_loss: 0.23475555419921876
step 0: loss = 0.2331
step 50: loss = 0.2230
step 100: loss = 0.2239
step 150: loss = 0.2119
step 200: loss = 0.2088
step 250: loss = 0.2113
step 300: loss = 0.1965
step 350: loss = 0.2063
step 400: loss = 0.1881
step 450: loss = 0.1842
test_loss: 0.17310842514038086
step 0: loss = 0.1778
step 50: loss = 0.1715
step 100: loss = 0.1797
step 150: loss = 0.1687
step 200: loss = 0.1729
step 250: loss = 0.1554
step 300: loss = 0.1652
step 350: loss = 0.1647
step 400: loss = 0.1803
step 450: loss = 0.1492
test_loss: 0.14740344047546386
step 0: loss = 0.1483
step 50: loss = 0.1568
step 100: loss = 0.1408
step 150: loss = 0.1389
step 200: loss = 0.1481
step 250: loss = 0.1390
step 300: loss = 0.1373
step 350: loss = 0.1312
step 40

step 350: loss = 0.0632
step 400: loss = 0.0680
step 450: loss = 0.0671
test_loss: 0.06661455631256104
step 0: loss = 0.0673
step 50: loss = 0.0636
step 100: loss = 0.0685
step 150: loss = 0.0680
step 200: loss = 0.0668
step 250: loss = 0.0697
step 300: loss = 0.0669
step 350: loss = 0.0665
step 400: loss = 0.0640
step 450: loss = 0.0693
test_loss: 0.06656136512756347
step 0: loss = 0.0669
step 50: loss = 0.0613
step 100: loss = 0.0649
step 150: loss = 0.0683
step 200: loss = 0.0644
step 250: loss = 0.0674
step 300: loss = 0.0672
step 350: loss = 0.0670
step 400: loss = 0.0669
step 450: loss = 0.0650
test_loss: 0.06557498455047607
step 0: loss = 0.0660
step 50: loss = 0.0644
step 100: loss = 0.0664
step 150: loss = 0.0655
step 200: loss = 0.0675
step 250: loss = 0.0640
step 300: loss = 0.0643
step 350: loss = 0.0649
step 400: loss = 0.0617
step 450: loss = 0.0662
test_loss: 0.06756846904754639
step 0: loss = 0.0650
step 50: loss = 0.0672
step 100: loss = 0.0649
step 150: loss = 0.0665


step 100: loss = 0.0535
step 150: loss = 0.0549
step 200: loss = 0.0567
step 250: loss = 0.0565
step 300: loss = 0.0544
step 350: loss = 0.0558
step 400: loss = 0.0568
step 450: loss = 0.0586
test_loss: 0.05580528736114502
step 0: loss = 0.0530
step 50: loss = 0.0581
step 100: loss = 0.0575
step 150: loss = 0.0582
step 200: loss = 0.0546
step 250: loss = 0.0532
step 300: loss = 0.0561
step 350: loss = 0.0562
step 400: loss = 0.0551
step 450: loss = 0.0560
test_loss: 0.05457709789276123
step 0: loss = 0.0545
step 50: loss = 0.0523
step 100: loss = 0.0513
step 150: loss = 0.0582
step 200: loss = 0.0575
step 250: loss = 0.0563
step 300: loss = 0.0540
step 350: loss = 0.0543
step 400: loss = 0.0548
step 450: loss = 0.0563
test_loss: 0.05488643646240234
step 0: loss = 0.0548
step 50: loss = 0.0550
step 100: loss = 0.0577
step 150: loss = 0.0543
step 200: loss = 0.0529
step 250: loss = 0.0544
step 300: loss = 0.0549
step 350: loss = 0.0547
step 400: loss = 0.0560
step 450: loss = 0.0549
test

step 450: loss = 0.0523
test_loss: 0.05127671241760254
step 0: loss = 0.0502
step 50: loss = 0.0503
step 100: loss = 0.0522
step 150: loss = 0.0514
step 200: loss = 0.0510
step 250: loss = 0.0513
step 300: loss = 0.0517
step 350: loss = 0.0509
step 400: loss = 0.0562
step 450: loss = 0.0509
test_loss: 0.05165863037109375
step 0: loss = 0.0495
step 50: loss = 0.0503
step 100: loss = 0.0521
step 150: loss = 0.0510
step 200: loss = 0.0489
step 250: loss = 0.0546
step 300: loss = 0.0531
step 350: loss = 0.0516
step 400: loss = 0.0501
step 450: loss = 0.0506
test_loss: 0.05186009407043457
step 0: loss = 0.0521
step 50: loss = 0.0505
step 100: loss = 0.0505
step 150: loss = 0.0506
step 200: loss = 0.0536
step 250: loss = 0.0540
step 300: loss = 0.0503
step 350: loss = 0.0508
step 400: loss = 0.0512
step 450: loss = 0.0506
test_loss: 0.050396156311035153
step 0: loss = 0.0527
step 50: loss = 0.0507
step 100: loss = 0.0540
step 150: loss = 0.0506
step 200: loss = 0.0516
step 250: loss = 0.0500

step 200: loss = 0.0483
step 250: loss = 0.0482
step 300: loss = 0.0506
step 350: loss = 0.0504
step 400: loss = 0.0503
step 450: loss = 0.0482
test_loss: 0.05014061450958252
step 0: loss = 0.0504
step 50: loss = 0.0482
step 100: loss = 0.0500
step 150: loss = 0.0499
step 200: loss = 0.0488
step 250: loss = 0.0508
step 300: loss = 0.0485
step 350: loss = 0.0492
step 400: loss = 0.0511
step 450: loss = 0.0496
test_loss: 0.05040798187255859
step 0: loss = 0.0494
step 50: loss = 0.0476
step 100: loss = 0.0518
step 150: loss = 0.0516
step 200: loss = 0.0500
step 250: loss = 0.0492
step 300: loss = 0.0489
step 350: loss = 0.0479
step 400: loss = 0.0502
step 450: loss = 0.0507
test_loss: 0.04958394527435303
step 0: loss = 0.0500
step 50: loss = 0.0491
step 100: loss = 0.0475
step 150: loss = 0.0471
step 200: loss = 0.0515
step 250: loss = 0.0514
step 300: loss = 0.0512
step 350: loss = 0.0492
step 400: loss = 0.0498
step 450: loss = 0.0528
test_loss: 0.05033284664154053
step 0: loss = 0.0489

step 0: loss = 0.0494
step 50: loss = 0.0500
step 100: loss = 0.0484
step 150: loss = 0.0483
step 200: loss = 0.0508
step 250: loss = 0.0485
step 300: loss = 0.0490
step 350: loss = 0.0494
step 400: loss = 0.0479
step 450: loss = 0.0485
test_loss: 0.04902827262878418
step 0: loss = 0.0464
step 50: loss = 0.0499
step 100: loss = 0.0476
step 150: loss = 0.0483
step 200: loss = 0.0504
step 250: loss = 0.0484
step 300: loss = 0.0468
step 350: loss = 0.0490
step 400: loss = 0.0474
step 450: loss = 0.0512
test_loss: 0.0487215518951416
step 0: loss = 0.0482
step 50: loss = 0.0472
step 100: loss = 0.0494
step 150: loss = 0.0491
step 200: loss = 0.0496
step 250: loss = 0.0491
step 300: loss = 0.0492
step 350: loss = 0.0490
step 400: loss = 0.0499
step 450: loss = 0.0513
test_loss: 0.04895973682403564
step 0: loss = 0.0472
step 50: loss = 0.0480
step 100: loss = 0.0477
step 150: loss = 0.0482
step 200: loss = 0.0462
step 250: loss = 0.0492
step 300: loss = 0.0489
step 350: loss = 0.0503
step 400

step 350: loss = 0.0475
step 400: loss = 0.0474
step 450: loss = 0.0451
test_loss: 0.0483718204498291
step 0: loss = 0.0489
step 50: loss = 0.0454
step 100: loss = 0.0476
step 150: loss = 0.0449
step 200: loss = 0.0511
step 250: loss = 0.0483
step 300: loss = 0.0493
step 350: loss = 0.0492
step 400: loss = 0.0490
step 450: loss = 0.0474
test_loss: 0.048009943962097165
step 0: loss = 0.0476
step 50: loss = 0.0522
step 100: loss = 0.0464
step 150: loss = 0.0479
step 200: loss = 0.0475
step 250: loss = 0.0487
step 300: loss = 0.0496
step 350: loss = 0.0505
step 400: loss = 0.0467
step 450: loss = 0.0469
test_loss: 0.04831006050109863
step 0: loss = 0.0493
step 50: loss = 0.0498
step 100: loss = 0.0470
step 150: loss = 0.0458
step 200: loss = 0.0504
step 250: loss = 0.0439
step 300: loss = 0.0478
step 350: loss = 0.0479
step 400: loss = 0.0491
step 450: loss = 0.0480
test_loss: 0.048287730216979984
step 0: loss = 0.0455
step 50: loss = 0.0482
step 100: loss = 0.0467
step 150: loss = 0.0459

step 100: loss = 0.0471
step 150: loss = 0.0475
step 200: loss = 0.0469
step 250: loss = 0.0460
step 300: loss = 0.0484
step 350: loss = 0.0490
step 400: loss = 0.0455
step 450: loss = 0.0494
test_loss: 0.047554821968078614
step 0: loss = 0.0458
step 50: loss = 0.0481
step 100: loss = 0.0474
step 150: loss = 0.0458
step 200: loss = 0.0461
step 250: loss = 0.0464
step 300: loss = 0.0492
step 350: loss = 0.0496
step 400: loss = 0.0470
step 450: loss = 0.0475
test_loss: 0.047812151908874514
step 0: loss = 0.0475
step 50: loss = 0.0448
step 100: loss = 0.0490
step 150: loss = 0.0491
step 200: loss = 0.0454
step 250: loss = 0.0501
step 300: loss = 0.0460
step 350: loss = 0.0455
step 400: loss = 0.0471
step 450: loss = 0.0453
test_loss: 0.047446908950805666
step 0: loss = 0.0499
step 50: loss = 0.0489
step 100: loss = 0.0493
step 150: loss = 0.0471
step 200: loss = 0.0465
step 250: loss = 0.0467
step 300: loss = 0.0496
step 350: loss = 0.0490
step 400: loss = 0.0486
step 450: loss = 0.0481
t

step 400: loss = 0.0505
step 450: loss = 0.0479
test_loss: 0.047130393981933597
step 0: loss = 0.0461
step 50: loss = 0.0450
step 100: loss = 0.0470
step 150: loss = 0.0472
step 200: loss = 0.0456
step 250: loss = 0.0471
step 300: loss = 0.0448
step 350: loss = 0.0462
step 400: loss = 0.0503
step 450: loss = 0.0479
test_loss: 0.04742093563079834
step 0: loss = 0.0477
step 50: loss = 0.0476
step 100: loss = 0.0451
step 150: loss = 0.0468
step 200: loss = 0.0451
step 250: loss = 0.0464
step 300: loss = 0.0467
step 350: loss = 0.0483
step 400: loss = 0.0447
step 450: loss = 0.0476
test_loss: 0.047083349227905275
step 0: loss = 0.0481
step 50: loss = 0.0476
step 100: loss = 0.0458
step 150: loss = 0.0471
step 200: loss = 0.0488
step 250: loss = 0.0475
step 300: loss = 0.0464
step 350: loss = 0.0478
step 400: loss = 0.0480
step 450: loss = 0.0489
test_loss: 0.047315669059753415
step 0: loss = 0.0443
step 50: loss = 0.0471
step 100: loss = 0.0477
step 150: loss = 0.0491
step 200: loss = 0.04

step 150: loss = 0.0451
step 200: loss = 0.0494
step 250: loss = 0.0462
step 300: loss = 0.0490
step 350: loss = 0.0496
step 400: loss = 0.0477
step 450: loss = 0.0469
test_loss: 0.0473290491104126
step 0: loss = 0.0462
step 50: loss = 0.0458
step 100: loss = 0.0471
step 150: loss = 0.0480
step 200: loss = 0.0462
step 250: loss = 0.0482
step 300: loss = 0.0459
step 350: loss = 0.0443
step 400: loss = 0.0461
step 450: loss = 0.0477
test_loss: 0.04672904968261719
step 0: loss = 0.0464
step 50: loss = 0.0480
step 100: loss = 0.0468
step 150: loss = 0.0456
step 200: loss = 0.0458
step 250: loss = 0.0472
step 300: loss = 0.0465
step 350: loss = 0.0459
step 400: loss = 0.0459
step 450: loss = 0.0468
test_loss: 0.04687005043029785
step 0: loss = 0.0455
step 50: loss = 0.0483
step 100: loss = 0.0444
step 150: loss = 0.0464
step 200: loss = 0.0447
step 250: loss = 0.0459
step 300: loss = 0.0463
step 350: loss = 0.0502
step 400: loss = 0.0470
step 450: loss = 0.0487
test_loss: 0.0474257469177246

test_loss: 0.046771059036254885
step 0: loss = 0.0454
step 50: loss = 0.0459
step 100: loss = 0.0487
step 150: loss = 0.0477
step 200: loss = 0.0462
step 250: loss = 0.0474
step 300: loss = 0.0456
step 350: loss = 0.0466
step 400: loss = 0.0464
step 450: loss = 0.0472
test_loss: 0.046689324378967285
step 0: loss = 0.0478
step 50: loss = 0.0464
step 100: loss = 0.0447
step 150: loss = 0.0488
step 200: loss = 0.0453
step 250: loss = 0.0462
step 300: loss = 0.0491
step 350: loss = 0.0459
step 400: loss = 0.0458
step 450: loss = 0.0480
test_loss: 0.04673484325408936
step 0: loss = 0.0440
step 50: loss = 0.0446
step 100: loss = 0.0434
step 150: loss = 0.0476
step 200: loss = 0.0436
step 250: loss = 0.0456
step 300: loss = 0.0475
step 350: loss = 0.0461
step 400: loss = 0.0453
step 450: loss = 0.0452
test_loss: 0.04691484451293945
step 0: loss = 0.0456
step 50: loss = 0.0456
step 100: loss = 0.0465
step 150: loss = 0.0463
step 200: loss = 0.0460
step 250: loss = 0.0472
step 300: loss = 0.043

step 250: loss = 0.0444
step 300: loss = 0.0447
step 350: loss = 0.0456
step 400: loss = 0.0446
step 450: loss = 0.0453
test_loss: 0.04558298587799072
step 0: loss = 0.0461
step 50: loss = 0.0447
step 100: loss = 0.0431
step 150: loss = 0.0446
step 200: loss = 0.0464
step 250: loss = 0.0439
step 300: loss = 0.0447
step 350: loss = 0.0436
step 400: loss = 0.0442
step 450: loss = 0.0453
test_loss: 0.045584826469421386
step 0: loss = 0.0456
step 50: loss = 0.0474
step 100: loss = 0.0454
step 150: loss = 0.0450
step 200: loss = 0.0477
step 250: loss = 0.0416
step 300: loss = 0.0429
step 350: loss = 0.0449
step 400: loss = 0.0460
step 450: loss = 0.0455
test_loss: 0.04558544158935547
step 0: loss = 0.0460
step 50: loss = 0.0450
step 100: loss = 0.0441
step 150: loss = 0.0466
step 200: loss = 0.0450
step 250: loss = 0.0458
step 300: loss = 0.0450
step 350: loss = 0.0472
step 400: loss = 0.0442
step 450: loss = 0.0451
test_loss: 0.04577095031738281
step 0: loss = 0.0466
step 50: loss = 0.0479

step 0: loss = 0.0437
step 50: loss = 0.0429
step 100: loss = 0.0435
step 150: loss = 0.0451
step 200: loss = 0.0427
step 250: loss = 0.0441
step 300: loss = 0.0438
step 350: loss = 0.0443
step 400: loss = 0.0449
step 450: loss = 0.0423
test_loss: 0.044403085708618166
step 0: loss = 0.0433
step 50: loss = 0.0454
step 100: loss = 0.0432
step 150: loss = 0.0443
step 200: loss = 0.0445
step 250: loss = 0.0439
step 300: loss = 0.0429
step 350: loss = 0.0453
step 400: loss = 0.0443
step 450: loss = 0.0448
test_loss: 0.044417109489440915
step 0: loss = 0.0419
step 50: loss = 0.0447
step 100: loss = 0.0417
step 150: loss = 0.0446
step 200: loss = 0.0436
step 250: loss = 0.0468
step 300: loss = 0.0465
step 350: loss = 0.0413
step 400: loss = 0.0449
step 450: loss = 0.0456
test_loss: 0.04464056015014648
step 0: loss = 0.0441
step 50: loss = 0.0421
step 100: loss = 0.0436
step 150: loss = 0.0420
step 200: loss = 0.0464
step 250: loss = 0.0422
step 300: loss = 0.0454
step 350: loss = 0.0447
step 

step 350: loss = 0.0435
step 400: loss = 0.0436
step 450: loss = 0.0438
test_loss: 0.04374153137207031
step 0: loss = 0.0443
step 50: loss = 0.0430
step 100: loss = 0.0420
step 150: loss = 0.0439
step 200: loss = 0.0424
step 250: loss = 0.0462
step 300: loss = 0.0431
step 350: loss = 0.0426
step 400: loss = 0.0452
step 450: loss = 0.0427
test_loss: 0.0438374137878418
step 0: loss = 0.0431
step 50: loss = 0.0452
step 100: loss = 0.0448
step 150: loss = 0.0435
step 200: loss = 0.0424
step 250: loss = 0.0447
step 300: loss = 0.0429
step 350: loss = 0.0432
step 400: loss = 0.0460
step 450: loss = 0.0442
test_loss: 0.043320412635803225
step 0: loss = 0.0430
step 50: loss = 0.0422
step 100: loss = 0.0434
step 150: loss = 0.0422
step 200: loss = 0.0436
step 250: loss = 0.0446
step 300: loss = 0.0433
step 350: loss = 0.0456
step 400: loss = 0.0419
step 450: loss = 0.0431
test_loss: 0.04339375019073486
step 0: loss = 0.0407
step 50: loss = 0.0442
step 100: loss = 0.0435
step 150: loss = 0.0437


step 100: loss = 0.0394
step 150: loss = 0.0417
step 200: loss = 0.0412
step 250: loss = 0.0399
step 300: loss = 0.0436
step 350: loss = 0.0427
step 400: loss = 0.0408
step 450: loss = 0.0410
test_loss: 0.0415607738494873
step 0: loss = 0.0427
step 50: loss = 0.0403
step 100: loss = 0.0423
step 150: loss = 0.0418
step 200: loss = 0.0429
step 250: loss = 0.0422
step 300: loss = 0.0415
step 350: loss = 0.0418
step 400: loss = 0.0417
step 450: loss = 0.0429
test_loss: 0.041389236450195314
step 0: loss = 0.0401
step 50: loss = 0.0413
step 100: loss = 0.0413
step 150: loss = 0.0425
step 200: loss = 0.0428
step 250: loss = 0.0419
step 300: loss = 0.0404
step 350: loss = 0.0415
step 400: loss = 0.0409
step 450: loss = 0.0415
test_loss: 0.041487507820129395
step 0: loss = 0.0407
step 50: loss = 0.0419
step 100: loss = 0.0412
step 150: loss = 0.0416
step 200: loss = 0.0392
step 250: loss = 0.0418
step 300: loss = 0.0409
step 350: loss = 0.0420
step 400: loss = 0.0411
step 450: loss = 0.0411
tes

step 450: loss = 0.0421
test_loss: 0.04138040065765381
step 0: loss = 0.0401
step 50: loss = 0.0405
step 100: loss = 0.0385
step 150: loss = 0.0420
step 200: loss = 0.0416
step 250: loss = 0.0427
step 300: loss = 0.0400
step 350: loss = 0.0425
step 400: loss = 0.0440
step 450: loss = 0.0418
test_loss: 0.041513471603393554
step 0: loss = 0.0393
step 50: loss = 0.0411
step 100: loss = 0.0394
step 150: loss = 0.0424
step 200: loss = 0.0406
step 250: loss = 0.0417
step 300: loss = 0.0425
step 350: loss = 0.0399
step 400: loss = 0.0417
step 450: loss = 0.0419
test_loss: 0.041382861137390134
step 0: loss = 0.0411
step 50: loss = 0.0415
step 100: loss = 0.0404
step 150: loss = 0.0413
step 200: loss = 0.0408
step 250: loss = 0.0404
step 300: loss = 0.0403
step 350: loss = 0.0417
step 400: loss = 0.0400
step 450: loss = 0.0410
test_loss: 0.041289634704589843
step 0: loss = 0.0402
step 50: loss = 0.0402
step 100: loss = 0.0412
step 150: loss = 0.0407
step 200: loss = 0.0431
step 250: loss = 0.04

step 200: loss = 0.0405
step 250: loss = 0.0409
step 300: loss = 0.0386
step 350: loss = 0.0405
step 400: loss = 0.0403
step 450: loss = 0.0418
test_loss: 0.04126838684082031
step 0: loss = 0.0399
step 50: loss = 0.0397
step 100: loss = 0.0408
step 150: loss = 0.0424
step 200: loss = 0.0432
step 250: loss = 0.0417
step 300: loss = 0.0416
step 350: loss = 0.0420
step 400: loss = 0.0407
step 450: loss = 0.0409
test_loss: 0.04109864234924317
step 0: loss = 0.0383
step 50: loss = 0.0412
step 100: loss = 0.0401
step 150: loss = 0.0418
step 200: loss = 0.0418
step 250: loss = 0.0404
step 300: loss = 0.0420
step 350: loss = 0.0412
step 400: loss = 0.0411
step 450: loss = 0.0417
test_loss: 0.04140695095062256
step 0: loss = 0.0391
step 50: loss = 0.0420
step 100: loss = 0.0399
step 150: loss = 0.0384
step 200: loss = 0.0438
step 250: loss = 0.0416
step 300: loss = 0.0428
step 350: loss = 0.0435
step 400: loss = 0.0386
step 450: loss = 0.0396
test_loss: 0.041314854621887206
step 0: loss = 0.041

step 0: loss = 0.0399
step 50: loss = 0.0409
step 100: loss = 0.0415
step 150: loss = 0.0413
step 200: loss = 0.0406
step 250: loss = 0.0426
step 300: loss = 0.0419
step 350: loss = 0.0407
step 400: loss = 0.0419
step 450: loss = 0.0407
test_loss: 0.04163354396820068
step 0: loss = 0.0403
step 50: loss = 0.0408
step 100: loss = 0.0431
step 150: loss = 0.0425
step 200: loss = 0.0398
step 250: loss = 0.0403
step 300: loss = 0.0421
step 350: loss = 0.0408
step 400: loss = 0.0406
step 450: loss = 0.0413
test_loss: 0.04138059139251709
step 0: loss = 0.0403
step 50: loss = 0.0403
step 100: loss = 0.0407
step 150: loss = 0.0391
step 200: loss = 0.0410
step 250: loss = 0.0409
step 300: loss = 0.0417
step 350: loss = 0.0417
step 400: loss = 0.0412
step 450: loss = 0.0413
test_loss: 0.0413386058807373
step 0: loss = 0.0391
step 50: loss = 0.0407
step 100: loss = 0.0406
step 150: loss = 0.0405
step 200: loss = 0.0415
step 250: loss = 0.0411
step 300: loss = 0.0425
step 350: loss = 0.0421
step 400

step 350: loss = 0.0429
step 400: loss = 0.0401
step 450: loss = 0.0417
test_loss: 0.04145503044128418
step 0: loss = 0.0415
step 50: loss = 0.0404
step 100: loss = 0.0420
step 150: loss = 0.0400
step 200: loss = 0.0417
step 250: loss = 0.0404
step 300: loss = 0.0429
step 350: loss = 0.0398
step 400: loss = 0.0421
step 450: loss = 0.0385
test_loss: 0.041299266815185545
step 0: loss = 0.0413
step 50: loss = 0.0417
step 100: loss = 0.0407
step 150: loss = 0.0401
step 200: loss = 0.0413
step 250: loss = 0.0397
step 300: loss = 0.0393
step 350: loss = 0.0410
step 400: loss = 0.0419
step 450: loss = 0.0417
test_loss: 0.041299057006835935
step 0: loss = 0.0418
step 50: loss = 0.0428
step 100: loss = 0.0406
step 150: loss = 0.0418
step 200: loss = 0.0414
step 250: loss = 0.0405
step 300: loss = 0.0412
step 350: loss = 0.0405
step 400: loss = 0.0398
step 450: loss = 0.0402
test_loss: 0.04134538173675537
step 0: loss = 0.0431
step 50: loss = 0.0405
step 100: loss = 0.0389
step 150: loss = 0.039

step 100: loss = 0.0401
step 150: loss = 0.0420
step 200: loss = 0.0422
step 250: loss = 0.0408
step 300: loss = 0.0402
step 350: loss = 0.0402
step 400: loss = 0.0424
step 450: loss = 0.0393
test_loss: 0.041292128562927244
step 0: loss = 0.0406
step 50: loss = 0.0427
step 100: loss = 0.0400
step 150: loss = 0.0392
step 200: loss = 0.0401
step 250: loss = 0.0408
step 300: loss = 0.0397
step 350: loss = 0.0408
step 400: loss = 0.0399
step 450: loss = 0.0416
test_loss: 0.04085536956787109
step 0: loss = 0.0394
step 50: loss = 0.0426
step 100: loss = 0.0416
step 150: loss = 0.0397
step 200: loss = 0.0430
step 250: loss = 0.0430
step 300: loss = 0.0380
step 350: loss = 0.0403
step 400: loss = 0.0392
step 450: loss = 0.0407
test_loss: 0.040982961654663086
step 0: loss = 0.0430
step 50: loss = 0.0409
step 100: loss = 0.0410
step 150: loss = 0.0412
step 200: loss = 0.0443
step 250: loss = 0.0406
step 300: loss = 0.0408
step 350: loss = 0.0400
step 400: loss = 0.0434
step 450: loss = 0.0429
te

step 400: loss = 0.0405
step 450: loss = 0.0407
test_loss: 0.04108291625976562
step 0: loss = 0.0400
step 50: loss = 0.0414
step 100: loss = 0.0404
step 150: loss = 0.0387
step 200: loss = 0.0413
step 250: loss = 0.0405
step 300: loss = 0.0398
step 350: loss = 0.0414
step 400: loss = 0.0415
step 450: loss = 0.0393
test_loss: 0.04134847164154053
step 0: loss = 0.0427
step 50: loss = 0.0385
step 100: loss = 0.0396
step 150: loss = 0.0395
step 200: loss = 0.0410
step 250: loss = 0.0392
step 300: loss = 0.0412
step 350: loss = 0.0400
step 400: loss = 0.0421
step 450: loss = 0.0437
test_loss: 0.04109816074371338
step 0: loss = 0.0405
step 50: loss = 0.0396
step 100: loss = 0.0393
step 150: loss = 0.0400
step 200: loss = 0.0426
step 250: loss = 0.0428
step 300: loss = 0.0401
step 350: loss = 0.0421
step 400: loss = 0.0413
step 450: loss = 0.0417
test_loss: 0.04082503795623779
step 0: loss = 0.0415
step 50: loss = 0.0402
step 100: loss = 0.0418
step 150: loss = 0.0430
step 200: loss = 0.0404


step 150: loss = 0.0398
step 200: loss = 0.0409
step 250: loss = 0.0400
step 300: loss = 0.0421
step 350: loss = 0.0425
step 400: loss = 0.0420
step 450: loss = 0.0400
test_loss: 0.04142539501190186
step 0: loss = 0.0404
step 50: loss = 0.0420
step 100: loss = 0.0409
step 150: loss = 0.0412
step 200: loss = 0.0424
step 250: loss = 0.0418
step 300: loss = 0.0415
step 350: loss = 0.0411
step 400: loss = 0.0421
step 450: loss = 0.0415
test_loss: 0.040949230194091794
step 0: loss = 0.0399
step 50: loss = 0.0381
step 100: loss = 0.0430
step 150: loss = 0.0423
step 200: loss = 0.0421
step 250: loss = 0.0405
step 300: loss = 0.0423
step 350: loss = 0.0437
step 400: loss = 0.0380
step 450: loss = 0.0396
test_loss: 0.04104677200317383
step 0: loss = 0.0406
step 50: loss = 0.0410
step 100: loss = 0.0394
step 150: loss = 0.0417
step 200: loss = 0.0417
step 250: loss = 0.0406
step 300: loss = 0.0442
step 350: loss = 0.0393


### Some debugging tests:

In [None]:
# Compare the fitted volume form with Omega_Omegabar on a few test batches.
# Reuses the `volume_form` defined above instead of redefining it, and
# iterates `test_set` (the previous code referenced an undefined `dataset`).
n_debug_batches = 1  # each batch prints full tensors, so keep this small
for points, Omega_Omegabar, mass, restriction in test_set.take(n_debug_batches):
    omega = volume_form(points, Omega_Omegabar, mass, restriction)
    print('omega', omega)
    print('OO', Omega_Omegabar)
    # Pointwise relative error |OO - omega| / OO. tf.abs returns a real
    # tensor, so it is cast back to complex64 to divide by Omega_Omegabar.
    print(tf.cast(tf.abs(Omega_Omegabar - omega), dtype=tf.complex64) / Omega_Omegabar)

In [16]:
# Print every trainable variable of the model (its repr includes name, shape,
# dtype and values).
for trainable_variable in model.trainable_weights:
    print(trainable_variable)

<tf.Variable 'Variable:0' shape=(5, 15) dtype=complex64, numpy=
array([[ 2.2768168e+00+1.6352530e+00j,  5.8658237e+00-4.2718673e-01j,
         5.3873414e-01-1.7244281e-01j, -3.9534414e-01+7.2521999e-02j,
        -1.8923687e-02+1.1802919e-02j, -3.8347977e-01-2.2627316e-01j,
        -7.6662725e-01+8.0790210e-01j,  2.6793966e-01+6.7697294e-02j,
        -4.4106621e-01-1.5242220e-01j, -1.6198709e+00-5.7326740e-01j,
         2.8095651e+00+8.1460869e-01j, -1.1343197e+00-2.5883400e+00j,
        -4.6649609e+00+1.1073344e+00j,  1.7060937e-01-2.2714563e-02j,
         8.5827917e-01-7.9530555e-01j],
       [ 1.9180773e+00+4.4432992e-01j, -1.7982172e+00+1.0105304e+00j,
        -6.0589733e+00-8.5629034e-01j,  2.9838247e+00-3.7910095e-01j,
         2.8432586e+00+9.5829725e-01j, -4.8352519e-01-4.6606249e-01j,
        -1.7036586e+00-9.6777737e-01j, -3.2158126e-03+1.6477600e-03j,
         1.7850419e+00+6.1775154e-01j,  4.3784446e-01-8.8729739e-02j,
         2.8238079e-01+1.9228293e-01j, -2.9208440e-01+1.