
MPS circuit with JIT compilation #204

Closed
Muzhou-Ma opened this issue Jan 28, 2024 · 8 comments

@Muzhou-Ma

Issue Description

When I use TensorFlow as the backend and construct an MPSCircuit, I compile with jit. But it seems that JIT is not working, and I get warnings like WARNING:tensorflow:Using a while_loop for converting Qr cause there is no registered converter for this op. and WARNING:tensorflow:Using a while_loop for converting SVD cause there is no registered converter for this op.

Proposed Solution

Make the operators in MPSCircuit jittable.

@refraction-ray
Contributor

Please attach a reproducible demo.

@Muzhou-Ma
Author

Muzhou-Ma commented Apr 23, 2024

Hi, here is a reproducible demo:

import tensorcircuit as tc
import tensorflow as tf
import numpy as np

tc.set_backend("tensorflow")
tc.set_dtype("complex64")
def Hamiltonian(c: tc.MPSCircuit, n: int):
    # sum of single-qubit Z expectations; the cast discards the (numerically zero) imaginary part
    e = 0.0
    for i in range(n):
        e += -1 * tf.cast(c.expectation_ps(z=[i]), tf.float64)
    return -tc.backend.real(e)


def vqe(params, n):
    circuit = tc.MPSCircuit(n)
    circuit.set_split_rules({"max_singular_values": 50})

    for i in range(n):
        circuit.rx(i, theta=params[i][0])
        circuit.ry(i, theta=params[i][1])
        circuit.rz(i, theta=params[i][2])

    energy = Hamiltonian(circuit, n)
    return energy

vqe_vvag = tc.backend.jit(
    tc.backend.vectorized_value_and_grad(vqe, vectorized_argnums=(0,)), static_argnums=(1,)
)

if __name__=="__main__":
    batch = 16
    n = 8
    maxiter = 100
    params = tf.Variable(
        initial_value=tf.concat(
            [
                tf.random.normal(shape=[batch // 4, n, 3], mean=0, stddev=0.2, dtype=getattr(tf, tc.rdtypestr)),
                tf.random.normal(shape=[batch // 4, n, 3], mean=np.pi / 4, stddev=0.2, dtype=getattr(tf, tc.rdtypestr)),
                tf.random.normal(shape=[batch // 4, n, 3], mean=np.pi / 2, stddev=0.2, dtype=getattr(tf, tc.rdtypestr)),
                tf.random.normal(shape=[batch // 4, n, 3], mean=np.pi * 3 / 4, stddev=0.2, dtype=getattr(tf, tc.rdtypestr)),
            ],
            axis=0,
        )
    )
    opt = tf.keras.optimizers.legacy.Adam(1e-2)
    for i in range(maxiter):
        energy, grad = vqe_vvag(params, n)
        opt.apply_gradients([(grad, params)])
        print(energy)

Thanks a lot!

@refraction-ray
Contributor

Thanks for providing the demo, but I can run it successfully with no error; my environment info is attached below:

>>> tc.about()
OS info: macOS-10.15.7-x86_64-i386-64bit
Python version: 3.10.0
Numpy version: 1.24.3
Scipy version: 1.10.1
Pandas version: 2.0.3
TensorNetwork version: 0.5.0
Cotengra version: 0.6.0
TensorFlow version: 2.13.0
TensorFlow GPU: []
TensorFlow CUDA infos: {'is_cuda_build': False, 'is_rocm_build': False, 'is_tensorrt_build': False}
Jax version: 0.4.14
Jax installation doesn't support GPU
JaxLib version: 0.4.14
PyTorch version: 2.0.1
PyTorch GPU support: False
PyTorch GPUs: []
Cupy is not installed
Qiskit version: 0.45.1
Cirq version: 1.2.0
TensorCircuit version 0.12.0

@refraction-ray
Contributor

refraction-ray commented Apr 23, 2024

Ah, you mean the warning. I do see it as well, but I believe it doesn't affect the results. I will investigate further whether the warning has a negative effect on jit and whether we can get rid of it.

Checked now! The warning is related not to jit but to vmap: if we use value_and_grad instead of vvag, the warning is gone. The reason for the warning is that TensorFlow has no vectorized implementation registered for QR, so the batched map falls back to a while_loop.
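
For reference, the same warning can be reproduced outside tensorcircuit in a couple of lines; a minimal standalone sketch (not library code):

import tensorflow as tf

# A batch of random matrices to decompose.
batch = tf.random.normal([8, 4, 4])

# tf.vectorized_map has no converter registered for QR (or SVD),
# so it falls back to a while_loop and emits the same warning:
#   WARNING:tensorflow:Using a while_loop for converting Qr ...
q, r = tf.vectorized_map(tf.linalg.qr, batch)

# The fallback is still correct, just not vectorized.
print(q.shape, r.shape)  # (8, 4, 4) (8, 4, 4)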

@refraction-ray
Contributor

If you feel TF is not fast enough, you can always run the following snippet with your actual circuit and hyperparameters to determine which backend is more suitable (tf vs. jax):

import tensorcircuit as tc
import numpy as np
import time

tc.set_dtype("complex64")


def Hamiltonian(c: tc.MPSCircuit, n: int):
    e = 0.0
    for i in range(n):
        e += -1 * c.expectation_ps(z=[i])
    return -tc.backend.real(e)


def vqe(params, n):
    circuit = tc.MPSCircuit(n)
    circuit.set_split_rules({"max_singular_values": 50})

    for i in range(n):
        circuit.rx(i, theta=params[i][0])
        circuit.ry(i, theta=params[i][1])
        circuit.rz(i, theta=params[i][2])
    for i in range(n - 1):
        circuit.cx(i, i + 1)

    energy = Hamiltonian(circuit, n)
    return energy


if __name__ == "__main__":
    batch = 16
    n = 16
    maxiter = 100
    params0 = np.random.uniform(size=[batch, n, 3])

    for b in ["tensorflow", "jax"]:
        with tc.runtime_backend(b):
            vqe_vvag = tc.backend.jit(
                tc.backend.vectorized_value_and_grad(vqe, vectorized_argnums=(0,)),
                static_argnums=(1,),
            )
            print("benchmarking backend: %s" % b)
            time0 = time.time()
            params = tc.backend.convert_to_tensor(params0)
            energy, grad = vqe_vvag(params, n)
            print(energy, grad)
            print("jit time", time.time() - time0)
            time0 = time.time()
            for _ in range(5):
                energy, grad = vqe_vvag(params, n)
            print("running time", (time.time() - time0) / 5)

@Muzhou-Ma
Author

Aha, I see. Thanks a lot!
So it seems that we can't use vvag to speed things up with tf as the backend.

@Muzhou-Ma
Author

I will close this issue, many thanks!

@refraction-ray
Contributor

Aha, I see. Thanks a lot! So it seems that we can't use vvag to speed things up with tf as the backend.

On this point, I don't know. Maybe you can run some microbenchmarks comparing vvag over a batch against a naive for loop with the tf backend. It is also possible that the other operations are vectorized, which may still make vvag more efficient than a for loop.
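
A minimal sketch of such a microbenchmark, assuming the vqe function and settings from the demo above (the naive_loop helper is illustrative, not tensorcircuit API):

import time
import tensorflow as tf
import tensorcircuit as tc

tc.set_backend("tensorflow")

# vvag: one batched, jitted call over the whole parameter batch.
# `vqe` is the function defined in the demo above.
vqe_vvag = tc.backend.jit(
    tc.backend.vectorized_value_and_grad(vqe, vectorized_argnums=(0,)),
    static_argnums=(1,),
)

# Naive baseline: jitted value_and_grad applied in a Python for loop.
vqe_vg = tc.backend.jit(tc.backend.value_and_grad(vqe, argnums=0), static_argnums=(1,))

def naive_loop(params, n):
    return [vqe_vg(params[i], n) for i in range(params.shape[0])]

batch, n = 16, 8
params = tf.random.normal([batch, n, 3], stddev=0.2, dtype=getattr(tf, tc.rdtypestr))

for name, fn in [("vvag", lambda: vqe_vvag(params, n)),
                 ("for loop", lambda: naive_loop(params, n))]:
    fn()  # warm-up / compilation
    t0 = time.time()
    for _ in range(5):
        fn()
    print(name, (time.time() - t0) / 5)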
