In [1]:
# Wildcard imports first: later imports shadow earlier names, so the explicit
# imports below are guaranteed to bind the real modules.
from hypersurface_tf import *
from generate_h import *
from complexNN import *

import time

import tensorflow as tf

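### Prepare the dataset

Define the quintic $\sum_{i=0}^{4} z_i^5 + 0.5\, z_0 z_1 z_2 z_3 z_4$, pick one coordinate patch of the hypersurface, and build a batched dataset on it.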

In [2]:
# Homogeneous coordinates and the defining polynomial
#   f = z0^5 + z1^5 + z2^5 + z3^5 + z4^5 + 0.5 * z0*z1*z2*z3*z4
z0, z1, z2, z3, z4 = sp.symbols('z0, z1, z2, z3, z4')
Z = [z0, z1, z2, z3, z4]
f = sum(z**5 for z in Z) + 0.5 * z0 * z1 * z2 * z3 * z4
# NOTE(review): the third argument (1000) is interpreted by Hypersurface;
# presumably the number of sampled points — confirm in hypersurface_tf.
HS = Hypersurface(Z, f, 1000)

In [3]:
# Pass k = 2 to the Hypersurface object.
# NOTE(review): presumably the degree used for the section basis / metric
# ansatz inside hypersurface_tf — confirm against Hypersurface.set_k.
HS.set_k(2)

In [4]:
# Work on a single coordinate patch: the first sub-patch of HS.patches[0].
# NOTE(review): the patch nesting is defined by hypersurface_tf — confirm
# which coordinate choices index 0 / 0 correspond to.
patch = HS.patches[0].patches[0]

In [5]:
# Generate the data on the chosen patch and batch it in a single statement,
# so re-running this cell never batches an already-batched dataset.
BATCH_SIZE = 200
dataset = generate_dataset(patch).batch(BATCH_SIZE)

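### Build the model

The model outputs the Kähler potential $K(z) = \log \sum_k |g_k(z)|^2$, where $g(z)$ is the output of a small complex-valued network.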

In [84]:
class KahlerPotential(tf.keras.Model):
    """Network ansatz for a Kähler potential.

    Two ComplexDense layers with squaring activations followed by a linear
    ComplexDense layer produce a feature vector g(z) per point; the model
    returns log(sum_k |g_k(z)|**2) for each point. The value is mathematically
    real, but the dtype stays complex.
    """

    def __init__(self):
        super().__init__()
        # NOTE(review): input width 5 presumably matches the five homogeneous
        # coordinates z0..z4 — confirm against ComplexDense's signature.
        self.layer_1 = ComplexDense(5, 30, activation=tf.square)
        self.layer_2 = ComplexDense(30, 10, activation=tf.square)
        self.g = ComplexDense(10, 10)

    def call(self, inputs):
        """Return log(sum_k |g_k(inputs)|**2), one value per row of ``inputs``.

        Assumes ``inputs`` is a complex (batch, 5) tensor — TODO confirm.
        """
        x = self.layer_1(inputs)
        x = self.layer_2(x)
        x = self.g(x)
        # Row-wise squared norm sum_k x_ik * conj(x_ik). This equals
        # diag(x @ x^H) without materialising the (batch, batch) product.
        x = tf.reduce_sum(x * tf.math.conj(x), axis=-1)
        return tf.math.log(x)

In [70]:
# Freshly constructed potential network; `volume_form` below reads this global.
model = KahlerPotential()

In [106]:
@tf.function
def volume_form(x, restriction):
    """Determinant of the restricted Kähler metric at each point of a batch.

    Takes the complex Hessian of Re(model(x)) with respect to ``x`` to obtain
    the metric g, then returns det(restriction @ g @ restriction^H).
    Reads the notebook-global ``model`` (looked up when tf.function traces).

    NOTE(review): assumes ``complex_hessians`` returns a one-element list
    whose tensor is indexed (batch, i, batch', j). tf.reduce_sum stacks the
    list, so axis=3 sums over batch' and [0] drops the stacking axis.
    TODO confirm.
    """
    potential = tf.math.real(model(x))
    hessian_list = complex_hessians(potential, x)
    kahler_metric = tf.reduce_sum(hessian_list, axis=3)[0]
    restricted_metric = tf.matmul(
        restriction,
        tf.matmul(kahler_metric, restriction, adjoint_b=True),
    )
    return tf.linalg.det(restricted_metric)

In [107]:
start_time = time.perf_counter()
# Evaluate the volume form on every batch. Only the last batch's result
# survives in `a`, which is inspected in the next cell. The elapsed time
# includes tf.function tracing on the first call, and likely a retrace when
# the final, smaller batch arrives with a new shape.
for x_batch_train, _y_batch, restriction in dataset:
    a = volume_form(x_batch_train, restriction)

print(time.perf_counter() - start_time)

2.3155465126037598


In [108]:
# Volume form of the last batch from the loop above; the imaginary parts are
# round-off (|Im| <= ~2e-6 in the recorded output).
a

<tf.Tensor: shape=(78,), dtype=complex64, numpy=
array([1.92081884e-01-1.94594364e-07j, 6.12835944e-01+2.14437065e-07j,
       4.31225671e-05-1.70882947e-10j, 8.04232713e-03-5.00550756e-09j,
       2.56405654e-03-1.33626155e-09j, 6.91423076e-04-2.28844287e-11j,
       9.30337701e-04+1.30397970e-09j, 5.05215935e-02-9.45321830e-08j,
       1.67884640e-02-8.55644355e-09j, 5.47784615e+00-1.80061409e-06j,
       2.66619891e-01-1.79391151e-07j, 7.55212456e-03-4.82645968e-09j,
       3.32566500e-02+4.36973089e-08j, 8.62117186e-02-8.02302793e-08j,
       1.12556489e-04-4.88629580e-12j, 7.02707672e+00-5.99435225e-07j,
       1.83564560e+02-8.54613006e-07j, 4.00650874e-02+1.05031646e-07j,
       4.38628107e-01+1.55132597e-07j, 7.17808083e-02-8.68636718e-08j,
       3.43019213e-03+5.39807476e-10j, 7.72475905e-04+5.95114402e-10j,
       8.08639005e-02-9.18419740e-09j, 1.62968084e-01-2.69058038e-08j,
       5.64808615e-05+2.69290867e-10j, 6.30953728e-05+2.31170534e-11j,
       1.98989227e-01-1.5511

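### A test for `complex_hessians()`

For $f(z) = \big(\sum_k |z_k|^2\big)^2$ the complex Hessian is $\dfrac{\partial^2 f}{\partial z_i\, \partial \bar z_j} = 2S\,\delta_{ij} + 2\bar z_i z_j$ with $S = \sum_k |z_k|^2$. At $z = (0.1i,\ 0.2i)$ we have $S = 0.05$, so the expected result is $\begin{pmatrix} 0.12 & 0.04 \\ 0.04 & 0.18 \end{pmatrix}$.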

In [97]:
# Purely imaginary test point z = (0.1i, 0.2i).
x = tf.constant([0.1j, 0.2j], dtype=tf.complex64)

In [98]:
@tf.function
def c_hessians(x):
    """Complex Hessian of f(z) = (sum_k |z_k|^2)**2 with respect to ``x``.

    Analytically d^2 f / (dz_i d conj(z_j)) = 2*S*delta_ij + 2*conj(z_i)*z_j
    with S = sum_k |z_k|^2, which gives a known value to compare against.
    """
    squared_norm = tf.tensordot(x, tf.math.conj(x), axes=1)
    potential = tf.square(tf.math.real(squared_norm))
    return complex_hessians(potential, x)

In [99]:
hessians = c_hessians(x)
# complex_hessians returns a list; check its single entry against the
# analytic Hessian 2*S*delta_ij + 2*conj(z_i)*z_j with S = |z|^2 = 0.05
# at z = (0.1i, 0.2i), so this cell fails loudly if the helper regresses.
expected_hessian = tf.constant([[0.12, 0.04], [0.04, 0.18]], dtype=tf.complex64)
tf.debugging.assert_near(hessians[0], expected_hessian, atol=1e-6)
hessians

[<tf.Tensor: shape=(2, 2), dtype=complex64, numpy=
 array([[0.12000001+0.j, 0.04      +0.j],
        [0.04      +0.j, 0.18      +0.j]], dtype=complex64)>]