In [1]:
from hypersurface_tf import *
from generate_h import *
from complexNN import *
import tensorflow as tf
import numpy as np

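### Prepare the dataset

Sample points on the hypersurface $z_0^5 + z_1^5 + z_2^5 + z_3^5 + z_4^5 + \tfrac{1}{2} z_0 z_1 z_2 z_3 z_4 = 0$ twice, with different NumPy seeds: one sample for training and a separate one for testing.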

In [2]:
# Fermat quintic in z0..z4 deformed by the term 0.5 * z0*z1*z2*z3*z4.
z0, z1, z2, z3, z4 = Z = list(sp.symbols('z0, z1, z2, z3, z4'))
f = sum(z**5 for z in Z) + 0.5*z0*z1*z2*z3*z4

n_points = 10000  # third argument passed to Hypersurface for both samples

# Reseed NumPy's global RNG right before each sample so that both point sets
# are reproducible and differ from each other.
np.random.seed(123)
HS = Hypersurface(Z, f, n_points)
np.random.seed(124)
HS_test = Hypersurface(Z, f, n_points)

In [3]:
# Turn each sampled hypersurface into a dataset whose elements are
# (points, Omega_Omegabar, mass, restriction) tuples; training sample first.
train_set, test_set = (generate_dataset(hs) for hs in (HS, HS_test))







In [4]:
# Shuffle and batch both datasets. NOTE: this cell rebinds train_set/test_set,
# so re-running it without re-running the cell above batches them twice.
batch_size = 1000
train_set = train_set.shuffle(500000).batch(batch_size)
test_set = test_set.shuffle(50000).batch(batch_size)

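### Build the model

The network's output is used as a Kähler potential: `volume_form` below takes its complex Hessian with respect to the input points to obtain the metric.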

In [14]:
class KahlerPotential(tf.keras.Model):
    """Network whose output is used as a Kähler potential of the input points.

    Layers (project classes from complexNN): ComplexDense(5, 15) with a
    tf.square activation, LinearTrans(15, 15), ComplexDense(15, 70) with a
    tf.square activation, then ComplexG(70). The result is
    log(sum_j |x_j|^2) taken over the last axis of ComplexG's output.
    """

    def __init__(self):
        super().__init__()
        # NOTE(review): layer_1 is created with trainable=False, yet the
        # trainable_weights dump in this notebook lists a (5, 15) variable,
        # which matches this layer's shape. Confirm ComplexDense honours the flag.
        self.layer_1 = ComplexDense(5, 15, activation=tf.square, trainable=False)
        self.layer_trans = LinearTrans(15, 15)
        self.layer_2 = ComplexDense(15, 70, activation=tf.square, trainable=True)
        self.g = ComplexG(70)

    def call(self, inputs):
        """Return log(sum_j |g_j|^2), where g is the ComplexG output for `inputs`."""
        x = self.layer_1(inputs)
        x = self.layer_trans(x)
        x = self.layer_2(x)
        x = self.g(x)
        # Squared norm along the last axis. This equals
        # diag_part(matmul(x, x, adjoint_b=True)) but does not materialise the
        # full batch-by-batch Gram matrix just to read off its diagonal.
        x = tf.reduce_sum(x * tf.math.conj(x), axis=-1)
        return tf.math.log(x)

In [15]:
# Instantiate the network. `volume_form` and the training/debugging cells read this global `model`.
model = KahlerPotential()

In [16]:
@tf.function
def volume_form(x, Omega_Omegabar, mass, restriction):
    """Normalised volume form of the metric given by the global `model`.

    The metric is the complex Hessian of Re(model(x)) with respect to `x`. It is
    sandwiched as restriction @ metric @ restriction^H, and the determinant of
    that product is rescaled by its mass-weighted mean ratio to Omega_Omegabar.
    After rescaling, the mass-weighted mean of (result / Omega_Omegabar) is 1.

    NOTE(review): `restriction` presumably projects onto the hypersurface's
    tangent directions; confirm against generate_dataset.
    """
    metric = complex_hessian(tf.math.real(model(x)), x)

    pulled_back = tf.matmul(restriction, tf.matmul(metric, restriction, adjoint_b=True))
    det_metric = tf.linalg.det(pulled_back)

    # Probability weights from the per-point masses.
    normalised_mass = mass / tf.reduce_sum(mass)
    scale = tf.reduce_sum(normalised_mass * det_metric / Omega_Omegabar)
    return det_metric / scale

In [None]:
# Train with a hand-written Adam optimizer (Kingma & Ba, 2015). The second
# moment uses grad * conj(grad) = |grad|^2, so the update is defined for the
# complex-valued weights as well.
epochs = 4000

alpha = 0.001  # Adam step size
beta1 = 0.9    # decay rate of the first-moment estimate
beta2 = 0.999  # decay rate of the second-moment estimate
eta = 1e-8     # added to sqrt(v) to avoid division by zero

# Early stopping: stop once the test loss has failed to improve on its best
# value for `patience` consecutive epochs (patience = 1 stops at the first
# epoch that does not improve).
patience = 20

m = [0] * len(model.trainable_weights)  # first-moment estimate per weight
v = [0] * len(model.trainable_weights)  # second-moment estimate per weight
t = 0                                   # global optimizer step counter

# Initialised once, outside the epoch loop: resetting the reference loss every
# epoch is what kept the previous early-stopping check from ever firing.
best_test_loss = float('inf')
epochs_without_improvement = 0

for epoch in range(epochs):
    for step, (points, Omega_Omegabar, mass, restriction) in enumerate(train_set):
        with tf.GradientTape() as tape:
            omega = volume_form(points, Omega_Omegabar, mass, restriction)
            loss = weighted_MAPE(Omega_Omegabar, omega, mass)
        # Take the gradient after leaving the tape context so the gradient
        # computation itself is not recorded on the tape.
        grads = tape.gradient(loss, model.trainable_weights)

        t += 1
        # Bias-corrected step size; it depends only on t, so compute it once per step.
        alpha_t = alpha * np.sqrt(1 - beta2**t) / (1 - beta1**t)

        for i, (weight, grad) in enumerate(zip(model.trainable_weights, grads)):
            m[i] = beta1 * m[i] + (1 - beta1) * grad
            v[i] = beta2 * v[i] + (1 - beta2) * tf.multiply(grad, tf.math.conj(grad))
            theta = alpha_t * m[i] / (tf.sqrt(v[i]) + eta)

            # Update in place rather than allocating a fresh tf.Variable copy
            # of every weight on every step.
            # NOTE(review): the CPU placement is kept from an earlier
            # assign_sub-on-a-CPU-copy workaround, presumably for complex
            # variables on GPU. Confirm whether it is still needed.
            with tf.device('/cpu:0'):
                weight.assign(weight - theta)

        if step % 500 == 0:
            print("step %d: loss = %.4f" % (step, tf.math.real(loss)))

    # Mean test loss over all test batches.
    test_loss = 0
    n_test_batches = 0
    for points, Omega_Omegabar, mass, restriction in test_set:
        omega = volume_form(points, Omega_Omegabar, mass, restriction)
        test_loss += weighted_MAPE(Omega_Omegabar, omega, mass)
        n_test_batches += 1

    test_loss = tf.math.real(test_loss).numpy() / n_test_batches
    print("test_loss:", test_loss)

    if test_loss < best_test_loss:
        best_test_loss = test_loss
        epochs_without_improvement = 0
    else:
        epochs_without_improvement += 1
        if epochs_without_improvement >= patience:
            print("early stopping at epoch %d, best test_loss = %.6f" % (epoch, best_test_loss))
            break

step 0: loss = 0.0090
test_loss: 0.009374412298202515
step 0: loss = 0.0092
test_loss: 0.009544750452041626
step 0: loss = 0.0089
test_loss: 0.009487475752830505
step 0: loss = 0.0090
test_loss: 0.009657870531082153
step 0: loss = 0.0093
test_loss: 0.009367501139640808
step 0: loss = 0.0088
test_loss: 0.009568567872047424
step 0: loss = 0.0095
test_loss: 0.009541394710540772
step 0: loss = 0.0092
test_loss: 0.009673133492469788
step 0: loss = 0.0092
test_loss: 0.009566249251365662
step 0: loss = 0.0088
test_loss: 0.00975422739982605
step 0: loss = 0.0097
test_loss: 0.009610427021980285
step 0: loss = 0.0096
test_loss: 0.009673517346382141
step 0: loss = 0.0092
test_loss: 0.009682464599609374
step 0: loss = 0.0091
test_loss: 0.009643289446830749
step 0: loss = 0.0094
test_loss: 0.009698024392127991
step 0: loss = 0.0096
test_loss: 0.009433353543281555
step 0: loss = 0.0088
test_loss: 0.009540339708328247
step 0: loss = 0.0090
test_loss: 0.009621804356575012
step 0: loss = 0.0094
test_lo

test_loss: 0.009483100175857543
step 0: loss = 0.0090
test_loss: 0.00948877513408661
step 0: loss = 0.0095
test_loss: 0.00948485791683197
step 0: loss = 0.0087
test_loss: 0.009407698512077331
step 0: loss = 0.0091
test_loss: 0.009586189389228821
step 0: loss = 0.0093
test_loss: 0.009729294776916505
step 0: loss = 0.0098
test_loss: 0.00948222041130066
step 0: loss = 0.0089
test_loss: 0.009574536681175232
step 0: loss = 0.0093
test_loss: 0.009628851413726807
step 0: loss = 0.0090
test_loss: 0.009490644335746765
step 0: loss = 0.0088
test_loss: 0.009498852491378783
step 0: loss = 0.0088
test_loss: 0.009553110599517823
step 0: loss = 0.0093
test_loss: 0.009540852904319764
step 0: loss = 0.0090
test_loss: 0.009486294984817505
step 0: loss = 0.0094
test_loss: 0.009510684013366699
step 0: loss = 0.0090
test_loss: 0.009452998638153076
step 0: loss = 0.0090
test_loss: 0.009489944577217102
step 0: loss = 0.0090
test_loss: 0.009571993350982666
step 0: loss = 0.0090
test_loss: 0.00945563018321991


step 0: loss = 0.0092
test_loss: 0.00936903953552246
step 0: loss = 0.0089
test_loss: 0.009333707094192505
step 0: loss = 0.0085
test_loss: 0.009378741979598999
step 0: loss = 0.0090
test_loss: 0.009603625535964966
step 0: loss = 0.0093
test_loss: 0.009594237208366394
step 0: loss = 0.0089
test_loss: 0.009654552936553954
step 0: loss = 0.0095
test_loss: 0.009554157257080078
step 0: loss = 0.0092
test_loss: 0.009487879872322082
step 0: loss = 0.0089
test_loss: 0.009575966596603393
step 0: loss = 0.0092
test_loss: 0.009478806853294372
step 0: loss = 0.0089
test_loss: 0.009394279718399047
step 0: loss = 0.0089
test_loss: 0.00948985755443573
step 0: loss = 0.0089
test_loss: 0.00936271607875824
step 0: loss = 0.0087
test_loss: 0.009688311219215394
step 0: loss = 0.0096
test_loss: 0.009423916935920715
step 0: loss = 0.0088
test_loss: 0.009611356258392333
step 0: loss = 0.0088
test_loss: 0.009379985332489014
step 0: loss = 0.0088
test_loss: 0.009409704208374024
step 0: loss = 0.0086
test_loss

test_loss: 0.009462733268737793
step 0: loss = 0.0092
test_loss: 0.009521654844284057
step 0: loss = 0.0095
test_loss: 0.00929696261882782
step 0: loss = 0.0087
test_loss: 0.009307093024253844
step 0: loss = 0.0087
test_loss: 0.009666650891304016
step 0: loss = 0.0092
test_loss: 0.00932434320449829
step 0: loss = 0.0089
test_loss: 0.00923300266265869
step 0: loss = 0.0091
test_loss: 0.0093774151802063
step 0: loss = 0.0097
test_loss: 0.009757463932037353
step 0: loss = 0.0091
test_loss: 0.009337984919548036
step 0: loss = 0.0089
test_loss: 0.009407848715782166
step 0: loss = 0.0084
test_loss: 0.009476780295372009
step 0: loss = 0.0088
test_loss: 0.00930970549583435
step 0: loss = 0.0089
test_loss: 0.00939538598060608
step 0: loss = 0.0088
test_loss: 0.00922407627105713
step 0: loss = 0.0089
test_loss: 0.009552507400512696
step 0: loss = 0.0090
test_loss: 0.009350388646125793
step 0: loss = 0.0091
test_loss: 0.00941051721572876
step 0: loss = 0.0088
test_loss: 0.00938973307609558
step 0

step 0: loss = 0.0087
test_loss: 0.009594451785087586
step 0: loss = 0.0093
test_loss: 0.009500896334648132
step 0: loss = 0.0087
test_loss: 0.009165288805961609
step 0: loss = 0.0088
test_loss: 0.009307503700256348
step 0: loss = 0.0085
test_loss: 0.009300997853279114
step 0: loss = 0.0091
test_loss: 0.009229044318199158
step 0: loss = 0.0088
test_loss: 0.009398529529571533
step 0: loss = 0.0090
test_loss: 0.009196272492408753
step 0: loss = 0.0086
test_loss: 0.009397608041763306
step 0: loss = 0.0089
test_loss: 0.009383633732795715
step 0: loss = 0.0089
test_loss: 0.009420974254608154
step 0: loss = 0.0087
test_loss: 0.009312160015106201
step 0: loss = 0.0088
test_loss: 0.009200961589813232
step 0: loss = 0.0088
test_loss: 0.009324052929878235
step 0: loss = 0.0088
test_loss: 0.009254478216171265
step 0: loss = 0.0089
test_loss: 0.009348561763763428
step 0: loss = 0.0088
test_loss: 0.009349888563156128
step 0: loss = 0.0091
test_loss: 0.009564566612243652
step 0: loss = 0.0095
test_l

test_loss: 0.009256112575531005
step 0: loss = 0.0088
test_loss: 0.009180520176887513
step 0: loss = 0.0086
test_loss: 0.009357097148895264
step 0: loss = 0.0086
test_loss: 0.00951068639755249
step 0: loss = 0.0090
test_loss: 0.009380071759223937
step 0: loss = 0.0090
test_loss: 0.009204322695732117
step 0: loss = 0.0087
test_loss: 0.009256626963615417
step 0: loss = 0.0087
test_loss: 0.009222795963287353
step 0: loss = 0.0090
test_loss: 0.009090502262115479
step 0: loss = 0.0083
test_loss: 0.009674981832504273
step 0: loss = 0.0094
test_loss: 0.009190874695777894
step 0: loss = 0.0090
test_loss: 0.009459108710289002
step 0: loss = 0.0090
test_loss: 0.009270002841949463
step 0: loss = 0.0089
test_loss: 0.009254018664360047
step 0: loss = 0.0088
test_loss: 0.009424105882644654
step 0: loss = 0.0087
test_loss: 0.009345014095306397
step 0: loss = 0.0090
test_loss: 0.00915932834148407
step 0: loss = 0.0086
test_loss: 0.009224420785903931
step 0: loss = 0.0085
test_loss: 0.00932754278182983

step 0: loss = 0.0090
test_loss: 0.009468648433685303
step 0: loss = 0.0091
test_loss: 0.00941927671432495
step 0: loss = 0.0090
test_loss: 0.009218792915344238
step 0: loss = 0.0084
test_loss: 0.009240161180496215
step 0: loss = 0.0089
test_loss: 0.009197253584861755
step 0: loss = 0.0081
test_loss: 0.009325242042541504
step 0: loss = 0.0090
test_loss: 0.009442656636238098
step 0: loss = 0.0094
test_loss: 0.009169229865074157
step 0: loss = 0.0083
test_loss: 0.009320728182792663
step 0: loss = 0.0090
test_loss: 0.009339910745620728
step 0: loss = 0.0090
test_loss: 0.00926833152770996
step 0: loss = 0.0089
test_loss: 0.009241839051246642
step 0: loss = 0.0088
test_loss: 0.009051544666290283
step 0: loss = 0.0082
test_loss: 0.009132473468780518
step 0: loss = 0.0086
test_loss: 0.008982700705528259
step 0: loss = 0.0084
test_loss: 0.009107506871223449
step 0: loss = 0.0082
test_loss: 0.0093228942155838
step 0: loss = 0.0087
test_loss: 0.0093459814786911
step 0: loss = 0.0092
test_loss: 0

test_loss: 0.009319987297058106
step 0: loss = 0.0084
test_loss: 0.009001184701919556
step 0: loss = 0.0086
test_loss: 0.009543954133987427
step 0: loss = 0.0091
test_loss: 0.009354785680770875
step 0: loss = 0.0086
test_loss: 0.00927104651927948
step 0: loss = 0.0083
test_loss: 0.009270716309547424
step 0: loss = 0.0093
test_loss: 0.008968934416770935
step 0: loss = 0.0088
test_loss: 0.009141297936439514
step 0: loss = 0.0090
test_loss: 0.009231497645378112
step 0: loss = 0.0090
test_loss: 0.009041218161582947
step 0: loss = 0.0087
test_loss: 0.009231114983558655
step 0: loss = 0.0089
test_loss: 0.009111796021461487
step 0: loss = 0.0083
test_loss: 0.009208092093467712
step 0: loss = 0.0088
test_loss: 0.009208779335021972
step 0: loss = 0.0083
test_loss: 0.009365463852882385
step 0: loss = 0.0089
test_loss: 0.009135335683822632
step 0: loss = 0.0089
test_loss: 0.00906499981880188
step 0: loss = 0.0085
test_loss: 0.009108151793479919
step 0: loss = 0.0085
test_loss: 0.00905484318733215

### Debugging checks

In [None]:
# Sanity check: compare the model's normalised volume form with Omega_Omegabar
# on a few test batches, and print the pointwise relative error.
# This reuses the `volume_form` defined for training. Defining a second copy
# here would silently shadow that one.
n_debug_batches = 1  # increase to inspect more batches (output grows quickly)

for points, Omega_Omegabar, mass, restriction in test_set.take(n_debug_batches):
    omega = volume_form(points, Omega_Omegabar, mass, restriction)

    print('omega', omega)
    print('OO', Omega_Omegabar)
    # |Omega_Omegabar - omega| / Omega_Omegabar, cast to complex to match dtypes.
    print(tf.cast(tf.abs(Omega_Omegabar - omega), dtype=tf.complex64) / Omega_Omegabar)

In [16]:
# Dump every trainable variable: name, shape, dtype and full values.
for variable in model.trainable_weights:
    print(variable)

<tf.Variable 'Variable:0' shape=(5, 15) dtype=complex64, numpy=
array([[ 2.2768168e+00+1.6352530e+00j,  5.8658237e+00-4.2718673e-01j,
         5.3873414e-01-1.7244281e-01j, -3.9534414e-01+7.2521999e-02j,
        -1.8923687e-02+1.1802919e-02j, -3.8347977e-01-2.2627316e-01j,
        -7.6662725e-01+8.0790210e-01j,  2.6793966e-01+6.7697294e-02j,
        -4.4106621e-01-1.5242220e-01j, -1.6198709e+00-5.7326740e-01j,
         2.8095651e+00+8.1460869e-01j, -1.1343197e+00-2.5883400e+00j,
        -4.6649609e+00+1.1073344e+00j,  1.7060937e-01-2.2714563e-02j,
         8.5827917e-01-7.9530555e-01j],
       [ 1.9180773e+00+4.4432992e-01j, -1.7982172e+00+1.0105304e+00j,
        -6.0589733e+00-8.5629034e-01j,  2.9838247e+00-3.7910095e-01j,
         2.8432586e+00+9.5829725e-01j, -4.8352519e-01-4.6606249e-01j,
        -1.7036586e+00-9.6777737e-01j, -3.2158126e-03+1.6477600e-03j,
         1.7850419e+00+6.1775154e-01j,  4.3784446e-01-8.8729739e-02j,
         2.8238079e-01+1.9228293e-01j, -2.9208440e-01+1.