In [1]:
from hypersurface_tf import *
from generate_h import *
from complexNN import *
import tensorflow as tf

### Prepare the dataset:

In [2]:
# Five sympy symbols, kept both as a list (Z) and as individual names.
Z = list(sp.symbols('z0, z1, z2, z3, z4'))
z0, z1, z2, z3, z4 = Z

# Defining polynomial: the sum of the fifth powers of all symbols plus
# 0.5 times the product of the five symbols.
f = sum(z**5 for z in Z) + 0.5*z0*z1*z2*z3*z4

# NOTE(review): the third argument is presumably the number of sample points
# -- confirm against the Hypersurface definition.
HS = Hypersurface(Z, f, 10000)

In [3]:
# One dataset per subpatch, generated in the same (patch, subpatch) order
# as a nested loop over HS.patches / patch.patches.
subpatch_datasets = [
    generate_dataset(subpatch)
    for patch in HS.patches
    for subpatch in patch.patches
]

# Chain the pieces together; `dataset` stays None if there were no subpatches.
dataset = None
for new_dataset in subpatch_datasets:
    dataset = new_dataset if dataset is None else dataset.concatenate(new_dataset)





In [4]:
# Batch the concatenated dataset with a batch size of 1000.
# NOTE(review): this rebinds `dataset` to its own batched version, so running
# this cell twice batches an already-batched dataset; rebuild `dataset` first
# when re-running.
dataset = dataset.batch(1000)

### Build the model:

In [5]:
class KahlerPotential(tf.keras.Model):
    """Keras model returning the log of the squared norm of its last layer.

    The input passes through two ComplexDense layers with a squaring
    activation and a third ComplexDense layer without an explicit
    activation. The output is ``log(sum_j |g_j|^2)`` per input row.

    NOTE(review): ComplexDense's positional arguments are presumably
    (input_dim, output_dim), i.e. 5 -> 20 -> 10 -> 10 -- confirm in complexNN.
    """

    def __init__(self):
        super().__init__()
        self.layer_1 = ComplexDense(5, 20, activation=tf.square)
        self.layer_2 = ComplexDense(20, 10, activation=tf.square)
        self.g = ComplexDense(10, 10)

    def call(self, inputs):
        """Return log(sum_j x_j * conj(x_j)) for each row of the last layer's output.

        Args:
            inputs: batch of points fed to the first ComplexDense layer.

        Returns:
            A tensor with the last axis of the final layer's output reduced
            away, in the dtype of that output.
        """
        x = self.layer_1(inputs)
        x = self.layer_2(x)
        x = self.g(x)
        # Row-wise squared norm. This equals the diagonal of x @ x^H but
        # avoids materialising the full (batch, batch) product just to throw
        # away its off-diagonal entries.
        x = tf.reduce_sum(x * tf.math.conj(x), axis=-1)
        x = tf.math.log(x)
        return x

In [6]:
# Module-level model instance. volume_form() and the training loop read this
# global directly rather than taking the model as an argument.
model = KahlerPotential()

In [7]:
@tf.function
def volume_form(x, weights, restriction):
    """Normalised determinant of the restricted complex Hessian of the model.

    Args:
        x: batch of points passed to the global ``model``.
        weights: per-point weights used in the normalisation sum.
        restriction: matrices applied as ``R @ H @ R^H`` to the Hessian.

    Returns:
        ``det(R @ H @ R^H)`` divided by ``sum(weights * det(R @ H @ R^H))``.

    NOTE(review): complex_hessian is defined outside this notebook; it is
    presumably the matrix of mixed holomorphic/antiholomorphic second
    derivatives -- confirm.
    """
    potential = tf.math.real(model(x))
    metric = complex_hessian(potential, x)

    # Sandwich the metric between the restriction and its adjoint.
    restricted_metric = tf.matmul(
        restriction, tf.matmul(metric, restriction, adjoint_b=True)
    )
    det_metric = tf.linalg.det(restricted_metric)

    normalisation = tf.reduce_sum(weights * det_metric)
    return det_metric / normalisation

In [29]:
# Hyperparameters. `learning_rate` is the single step size handed to the
# optimizer (previously an unused `learning_rate = 1` sat next to an
# optimizer that was hard-coded to 1e-1).
learning_rate = 1e-1
epochs = 2
optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate)

# NOTE(review): the traceback saved with this cell reports that no GPU kernel
# is registered for AssignSubVariableOp with dtype complex64 (the CPU entries
# in the same listing do include complex64). If that error reproduces with
# this optimizer-based update, the complex weights have to be updated on a
# device that provides the kernel.
for epoch in range(epochs):
    for step, (points, Omega_Omegabar, weights, restriction) in enumerate(dataset):
        start_time = time.time()
        with tf.GradientTape() as tape:
            omega = volume_form(points, weights, restriction)

            # Weighted reference values from the dataset vs. the weighted
            # model output.
            y_true = tf.math.real(weights * Omega_Omegabar)
            y_pred = tf.math.real(weights * omega)

            # Keras MAPE is 100 * mean(|y_true - y_pred| / |y_true|) over the
            # last axis; multiplying by len(points) / 100 undoes the
            # percentage and the mean's division by the batch length.
            loss = tf.keras.losses.mean_absolute_percentage_error(y_true, y_pred)
            loss = loss * len(points) / 100

        grads = tape.gradient(loss, model.trainable_weights)
        optimizer.apply_gradients(zip(grads, model.trainable_weights))

        # Wall-clock seconds for this step, then the step's loss.
        print(time.time() - start_time)
        print("step %d: loss = %.4f" % (step, loss))

NotFoundError: No registered 'AssignSubVariableOp' OpKernel for 'GPU' devices compatible with node {{node AssignSubVariableOp}}
	 (OpKernel was found, but attributes didn't match) Requested Attributes: dtype=DT_COMPLEX64
	.  Registered:  device='XLA_GPU'; dtype in [DT_FLOAT, DT_DOUBLE, DT_INT32, DT_UINT8, DT_INT16, DT_INT8, DT_COMPLEX64, DT_INT64, DT_BFLOAT16, DT_UINT16, DT_COMPLEX128, DT_HALF, DT_UINT32, DT_UINT64]
  device='XLA_CPU'; dtype in [DT_FLOAT, DT_DOUBLE, DT_INT32, DT_UINT8, DT_INT16, DT_INT8, DT_COMPLEX64, DT_INT64, DT_BFLOAT16, DT_UINT16, DT_COMPLEX128, DT_HALF, DT_UINT32, DT_UINT64]
  device='XLA_GPU_JIT'; dtype in [DT_FLOAT, DT_DOUBLE, DT_INT32, DT_UINT8, DT_INT16, DT_INT8, DT_COMPLEX64, DT_INT64, DT_BFLOAT16, DT_UINT16, DT_COMPLEX128, DT_HALF, DT_UINT32, DT_UINT64]
  device='XLA_CPU_JIT'; dtype in [DT_FLOAT, DT_DOUBLE, DT_INT32, DT_UINT8, DT_INT16, DT_INT8, DT_COMPLEX64, DT_INT64, DT_BFLOAT16, DT_UINT16, DT_COMPLEX128, DT_HALF, DT_UINT32, DT_UINT64]
  device='CPU'; dtype in [DT_COMPLEX128]
  device='CPU'; dtype in [DT_COMPLEX64]
  device='CPU'; dtype in [DT_DOUBLE]
  device='CPU'; dtype in [DT_FLOAT]
  device='CPU'; dtype in [DT_BFLOAT16]
  device='CPU'; dtype in [DT_HALF]
  device='CPU'; dtype in [DT_INT8]
  device='CPU'; dtype in [DT_UINT8]
  device='CPU'; dtype in [DT_INT16]
  device='CPU'; dtype in [DT_UINT16]
  device='CPU'; dtype in [DT_INT32]
  device='CPU'; dtype in [DT_INT64]
  device='GPU'; dtype in [DT_INT64]
  device='GPU'; dtype in [DT_DOUBLE]
  device='GPU'; dtype in [DT_FLOAT]
  device='GPU'; dtype in [DT_HALF]
 [Op:AssignSubVariableOp]

In [10]:
@tf.function
def volume_form(x, weights, restriction):
    """Unnormalised determinant of the restricted complex Hessian of the model.

    NOTE(review): this definition reuses the name of the earlier, normalised
    ``volume_form`` in this notebook and replaces it once this cell runs, so
    any training cell executed afterwards uses the unnormalised value.

    Args:
        x: batch of points passed to the global ``model``.
        weights: unused here; kept so the signature matches existing callers.
        restriction: matrices applied as ``R @ H @ R^H`` to the Hessian.

    Returns:
        ``det(R @ H @ R^H)`` for each element of the batch.
    """
    kahler_metric = complex_hessian(tf.math.real(model(x)), x)
    restricted_metric = tf.matmul(
        restriction, tf.matmul(kahler_metric, restriction, adjoint_b=True)
    )
    return tf.linalg.det(restricted_metric)

In [11]:
# Print volume_form for every batch of the dataset. The second dataset
# component (y) is not needed for this inspection.
for x, y, weights, restriction in dataset:
    print(volume_form(x, weights, restriction))


tf.Tensor(
[ 2.70501712e-20+4.80829039e-20j -1.47771608e-18+6.58276618e-18j
  2.36415959e-15+1.74314690e-15j  8.43443759e-19-2.34494640e-19j
  6.14645214e-20-9.98295423e-22j  1.61114560e-15+1.57181654e-16j
  1.25637439e-19+2.98170564e-19j  1.75856621e-19-1.16860401e-20j
  1.76794644e-15+2.13091281e-16j  4.25092931e-20-2.11954288e-20j
  9.24989840e-18+5.76841638e-19j  6.32760717e-16+4.23252676e-16j
  1.43000916e-19-9.53798907e-19j  5.37048184e-12+9.37593238e-14j
  1.23963603e-16+2.70719208e-17j  2.02873558e-20-6.92689511e-20j
 -1.60246774e-18+9.79763899e-19j -1.48106059e-20+3.85184381e-19j
 -1.76493395e-19+1.14782787e-19j  3.83217558e-18-6.40094485e-19j
 -2.00006094e-19+1.78710834e-22j  2.13533251e-17-5.69952092e-18j
  3.67048059e-19-1.62029803e-18j  1.06555574e-12-1.04314432e-13j
  7.67608422e-20+3.45627984e-20j  1.73880914e-15-6.76818130e-16j
  9.17479170e-21-1.35317862e-20j  4.46693413e-16+9.60699898e-16j
  2.18049765e-15-1.02763639e-16j  5.11679807e-20-6.48746058e-20j
  5.56306225e-

tf.Tensor(
[-4.65925931e-15-2.17433866e-14j -6.16470763e-20-1.22737525e-19j
  2.49829774e-20+1.69026846e-20j  4.87470055e-16+1.72892156e-16j
 -1.47809245e-18+5.82051263e-18j  8.74213747e-21+1.50845077e-20j
  4.69946580e-22-1.54479469e-20j  1.68439582e-17+3.20971059e-18j
  3.80544544e-19+6.29146918e-20j  2.51960637e-17-3.84781061e-17j
  1.17020878e-20+1.72835253e-20j  1.24196777e-17-9.24590642e-18j
  4.18847825e-18+2.81542554e-19j  1.26207110e-13-7.21924402e-15j
 -3.83114284e-19-4.81056243e-19j  2.99962477e-20-1.54181926e-20j
  2.38579294e-20-2.38211085e-20j  2.31855882e-20-5.81099421e-20j
  1.00682531e-19+1.45504543e-20j -1.76707589e-20+2.76473952e-21j
  5.11213818e-16-1.82570924e-16j -3.98860449e-20+1.36200148e-19j
  6.77105689e-18+2.65213709e-18j  2.79111629e-15-2.53701429e-16j
  2.36817560e-20+1.49721888e-20j -7.94579227e-20-1.04880511e-19j
  4.39837528e-15-8.03714518e-16j  1.46040838e-17+4.27862150e-18j
 -4.46493173e-19-3.50061155e-19j  1.11642293e-19-4.21402478e-21j
 -2.05695149e-

tf.Tensor(
[-7.71781524e-22+1.07376025e-20j  8.22559201e-16-1.61882806e-16j
 -2.05427618e-15+2.36816727e-16j -1.79691068e-19-3.37344126e-20j
  4.57397554e-14-1.04676188e-14j -7.53467317e-20-3.98628095e-20j
  1.48364536e-20-1.79699463e-20j  9.46512079e-20+1.07579736e-21j
  2.67886131e-17+1.49309727e-16j  7.91133220e-16-1.29039911e-16j
 -4.76232797e-19+2.42106799e-17j  5.68647897e-19-2.08795521e-19j
  2.11070465e-20+7.81921070e-21j  2.83688079e-19-2.34242040e-19j
 -2.00383130e-20+7.58602412e-21j  6.42331026e-19+1.81422008e-19j
  4.26506607e-18+3.52089014e-18j  9.96465309e-17+2.88320255e-17j
  2.86637474e-14-2.25315660e-14j  2.95386673e-16-1.17480399e-16j
  1.40836159e-19-1.30723643e-19j -1.14950084e-18-1.09528526e-18j
  5.54472503e-19-3.74655768e-19j  8.38651159e-17-6.75993214e-18j
  1.32920312e-19-9.51219318e-20j  8.63150036e-15-4.74533988e-15j
 -6.75475316e-18-9.63590388e-18j -2.97652865e-20+1.55945049e-20j
  1.53679157e-15+4.87748517e-16j -5.82654432e-21-1.82005471e-20j
 -2.89868463e-

tf.Tensor(
[-5.55104784e-22+9.81829864e-22j -1.27158938e-22+6.86504978e-22j
  7.81055103e-17+5.93167511e-17j  8.49331152e-20-9.79617236e-21j
 -7.47339783e-20-3.42512195e-20j -2.14351884e-20-1.18550039e-20j
  1.42146192e-21-6.07431227e-22j  8.21285555e-19-4.63841198e-18j
  8.13678375e-20+7.31226692e-20j -1.37378996e-19+1.97030364e-19j
  1.26094112e-21-1.32456456e-21j -9.96433677e-18-7.41232659e-18j
 -4.67182280e-21+5.41542344e-21j  4.62374317e-20+3.92379586e-20j
 -8.00647243e-20-8.28904638e-20j -2.31761649e-21+1.46731147e-21j
 -3.23435607e-19-1.70792262e-20j  8.46599182e-15+8.42932122e-15j
  5.69516954e-18+1.04755857e-17j  9.13574942e-21-8.79380798e-21j
  6.57666696e-19-2.19865899e-19j  2.09503127e-17-1.69045211e-17j
 -1.39832750e-19+1.76749271e-20j  6.33881104e-14-1.83789649e-14j
  7.12824289e-19+1.85974211e-18j -5.76010115e-19-4.89733283e-20j
 -8.72682211e-22-5.69867706e-21j -2.14239262e-20-1.28315771e-20j
  4.70402347e-21+3.59342041e-21j  6.07757927e-19-5.16579204e-19j
 -2.71185832e-

tf.Tensor(
[-1.86234355e-17-4.23926590e-17j  2.93558148e-18-5.01907015e-18j
  5.44026543e-20-1.54000687e-19j  3.04442597e-20-2.50526367e-20j
  4.00549287e-20-4.74765127e-20j -6.56879957e-21+1.57581202e-20j
 -2.29777804e-17+1.63550730e-17j  9.18386374e-18+5.54942197e-18j
 -1.30515183e-18-1.02613099e-18j  6.49371233e-21-2.75803420e-19j
 -2.23840183e-21+3.05379694e-21j  6.68716899e-19-1.11195020e-19j
 -1.46752067e-20-3.83663881e-20j  2.17031004e-16-2.91742532e-17j
  1.53979322e-19-1.66858977e-18j  5.87783945e-20+4.30314412e-20j
  2.97572054e-20+5.56623082e-20j  2.72396648e-15-1.89582755e-16j
  1.02162472e-18+6.15336349e-19j  3.22320816e-12-2.06913710e-13j
  1.37515791e-19+8.12756069e-20j  1.01467984e-20+5.39915245e-21j
 -3.85886622e-21-7.57663160e-22j -1.31534397e-19-2.99009584e-19j
 -3.80779489e-19+2.90457519e-19j -4.97474072e-19-1.01030217e-18j
  2.26252364e-18-3.89756622e-19j  1.00367237e-16+5.61779912e-17j
  1.16674045e-19+1.86635107e-19j  4.17654565e-14+9.34114541e-15j
  2.24923617e-

tf.Tensor(
[ 1.61478370e-20-5.25251805e-20j -1.90762739e-19-2.58613592e-20j
  1.59396505e-18-2.62526592e-20j  3.50802719e-16+7.29776721e-17j
  7.52819309e-18-1.13747037e-17j  1.56013027e-16-1.96082072e-17j
  1.52876354e-19-1.45398468e-18j  1.36460115e-19+1.95628460e-19j
  6.88924083e-15-6.10556468e-15j  5.43657284e-20+1.01868210e-19j
 -3.45184489e-21-2.52133133e-20j  3.67949202e-21-2.14187466e-20j
  1.23717624e-17+2.73902976e-17j -3.99857137e-20+1.04336303e-21j
  6.30961739e-19+2.42468403e-19j  3.35770834e-20+3.57708828e-20j
  4.16202872e-20+5.63790473e-21j -4.62147824e-19-5.65476306e-20j
  3.50689360e-21-5.56734961e-21j  3.73049571e-21+4.48275548e-21j
  8.12072094e-20-1.95995502e-20j  1.96272220e-20-1.30596156e-20j
  4.85933099e-20+4.61303312e-20j  2.36039064e-19+5.28349305e-20j
  3.40685240e-21+2.31999403e-21j  2.41884181e-19-1.30874707e-19j
 -1.13312202e-19-2.42390674e-20j  2.60195979e-20+3.20044075e-21j
  1.67445807e-17-2.08270885e-18j  5.80156694e-21+2.80891391e-21j
  2.99135247e-

tf.Tensor(
[-4.89978820e-20-4.01467121e-19j  6.50628152e-20+1.40094398e-19j
 -2.62571239e-21-7.46382020e-22j -3.96170076e-20+1.88904868e-20j
  3.70151886e-20+2.06898809e-20j -2.31413056e-19-2.48000535e-19j
 -3.86444662e-20+1.81941142e-21j  7.15882469e-18+2.39733795e-18j
  2.14160431e-18-1.61895014e-20j -3.27344733e-19+1.55165186e-18j
 -2.18929505e-19+9.24150156e-20j  1.09413942e-20-7.44936048e-20j
 -9.90239521e-19-1.00485591e-18j -9.88990706e-18-4.60129142e-18j
  1.74765725e-19-1.99134556e-20j -8.77155892e-20+5.03903404e-21j
 -5.57902369e-20-2.58992289e-21j  2.04193832e-20+5.71974876e-21j
  3.13150893e-18+1.12243149e-19j  2.34227512e-19+1.36862267e-20j
 -4.58998955e-19-2.82384805e-19j -1.08486754e-19+8.71946529e-20j
  1.02955564e-17+4.46557838e-18j  1.38218132e-19+9.04712574e-20j
  2.62854847e-20+8.00808414e-20j  1.65575368e-18-1.27026921e-18j
 -8.27359490e-19+6.11518600e-20j  7.28014789e-19-1.15442399e-19j
  1.22737642e-16-2.90767981e-17j  3.42302992e-11-1.28627603e-11j
 -4.23062326e-

tf.Tensor(
[ 7.41089048e-14-6.41404773e-15j  5.36182675e-17-2.06269529e-17j
  4.65494335e-17-6.88989106e-17j  7.50897107e-18-6.59025216e-18j
  3.97685206e-20+8.39205234e-20j  1.47787198e-17+2.11490528e-17j
 -1.20738467e-18-2.00580173e-18j  1.22655609e-20-4.99588090e-21j
  7.03928246e-20+3.93580213e-19j -5.15132179e-21-2.39506572e-21j
  5.26436737e-15+4.20969149e-15j  1.18766076e-16-4.55021599e-17j
 -1.21445963e-20-2.58094063e-21j  4.14141053e-19-5.73339239e-20j
 -1.00293788e-18+9.56426549e-20j  5.23730811e-18-2.03560794e-18j
 -4.38095144e-20+4.99889123e-20j -3.46215773e-19+1.67411494e-19j
  1.11829397e-21+7.30525931e-21j  7.87199290e-20+1.35006410e-19j
  1.63263203e-09-4.78445027e-11j  7.71655985e-20+1.83082535e-19j
 -3.08385521e-19-2.96507518e-19j  3.10497587e-15-9.36863571e-17j
  3.09514991e-20-9.62110480e-18j  1.63642547e-17+1.77674177e-18j
 -4.92073591e-21-7.15685966e-22j  1.87492661e-19+1.07763914e-19j
  3.63376646e-21+5.25382827e-22j -4.68872467e-20-3.94046032e-20j
 -1.95201907e-

tf.Tensor(
[ 1.94317807e-15-9.84350752e-16j -1.46233718e-21-1.15684076e-20j
  1.75825564e-16-4.47959429e-17j  6.21177743e-20-6.02303227e-20j
 -4.40834689e-19-1.72515419e-19j -4.59876676e-21-1.04445826e-20j
  6.18884191e-20+2.33106072e-20j  1.36892527e-20+8.58073547e-21j
  1.16042247e-19+5.65655313e-20j -8.83177913e-21-1.48618103e-20j
 -7.25234673e-20-4.84007513e-20j -1.64520791e-21+3.92031790e-21j
  1.96668758e-17-4.34180817e-18j  4.78837764e-20-6.65218261e-20j
  4.63142065e-18+9.84583713e-21j  6.76188485e-20-2.03835559e-21j
  7.21755393e-21-7.90381657e-21j  1.41632693e-17+5.45393761e-18j
  1.45505487e-17-2.37652696e-17j  3.49554243e-18+4.38808975e-19j
 -3.75308161e-20-1.97273266e-21j  7.24470668e-15-1.83750584e-15j
  1.04046209e-19+2.19620685e-20j -3.20335489e-19-6.68215886e-20j
 -2.27274336e-19+6.43194836e-20j  5.16536552e-19+2.83544880e-20j
  5.53106595e-20-5.19812414e-20j  2.82659441e-20+1.01144893e-21j
  1.72294773e-18-1.43706171e-18j  2.45531537e-20-2.18523573e-20j
  4.95286218e-

In [25]:
# Scratch check: reciprocal of a complex number.
# NOTE(review): where this value comes from is not shown in the notebook.
1/((1.5567119399999998-0.26916985000000004j))

(0.6237315291662895+0.1078489332753244j)

In [8]:
# Diagonal of a @ a^H (a times its own adjoint).
# NOTE(review): `a` is not defined in any visible cell, so this cell relies on
# leftover kernel state and will not survive Restart & Run All.
tf.linalg.diag_part(tf.matmul(a, a, adjoint_b=True))

<tf.Tensor: shape=(1000,), dtype=complex64, numpy=
array([6.70346916e-01+7.24273974e-09j, 8.74518299e+00+1.78387154e-08j,
       3.74678516e+00-1.00850492e-08j, 3.47875096e-02+6.86494861e-10j,
       1.02022696e+00-9.93402693e-10j, 1.71032678e-02+1.08401899e-10j,
       6.28842935e-02-2.14205598e-10j, 1.90958166e+00-5.43127499e-09j,
       6.51590824e+00+5.19976879e-08j, 2.88249826e+00-2.74261680e-08j,
       1.30466568e+02-7.50898835e-07j, 1.75445004e+01+1.53152087e-08j,
       1.19553680e+01-3.64723576e-08j, 2.36496496e+00+4.62482603e-08j,
       2.06168437e+00-6.85564316e-09j, 1.42739105e+01-4.63187462e-08j,
       8.59186096e+01-1.82804797e-07j, 2.84091549e+01+3.87515456e-07j,
       2.94046474e+00+2.87820967e-09j, 3.74946861e+01-6.23449523e-08j,
       5.85808182e+00-7.48563878e-10j, 1.58361988e+01+1.10382871e-07j,
       5.97979546e+00-5.23488666e-08j, 6.42857107e-04-4.77533898e-12j,
       1.80393849e-02+2.31761277e-11j, 4.54012841e-01+5.31899325e-09j,
       1.68039703e+02+1.60

In [24]:
# Display the scratch tensor `a`.
# NOTE(review): `a` is not assigned in any visible cell.
a

<tf.Tensor: shape=(1000, 1), dtype=complex64, numpy=
array([[ 1.17939472e-01+7.82026577e+00j],
       [ 1.92773247e+01-1.89247112e+01j],
       [ 1.81673107e+01+1.20645771e+01j],
       [ 1.25924671e+00-2.55466151e+00j],
       [-6.39472198e+00-3.22579908e+00j],
       [-1.01483130e+00-2.02565193e-01j],
       [ 3.14997339e+00-1.09767938e+00j],
       [ 7.13509369e+00+7.29912043e+00j],
       [-3.17161632e+00+5.29501009e+00j],
       [ 7.93774176e+00-1.73669136e+00j],
       [ 1.66230011e+01-7.72204161e+00j],
       [ 8.53851414e+00-9.65203822e-01j],
       [ 2.44792533e+00-1.20293272e+00j],
       [-2.48707390e+00+2.39707017e+00j],
       [-7.53825784e-01+1.04212797e+00j],
       [-2.35443068e+00+8.15270305e-01j],
       [ 4.05000150e-01-1.56702161e+00j],
       [-6.37256050e+00+7.60447359e+00j],
       [-6.54893970e+00+4.68202066e+00j],
       [ 2.13621883e-04-5.80288935e-03j],
       [ 1.19924259e+01+2.56329098e+01j],
       [ 2.59734344e+00+9.43338108e+00j],
       [-3.41419053e+00