In [1]:
import tensorflow as tf
import tensornetwork as tn
import numpy as np
import sys

sys.path.insert(0, "../")
import tensorcircuit as tc
import jax
import torch

In this small demo, we show how to write code once and run it on different backends (TensorFlow, JAX and PyTorch) and with different dtypes.
This is just for fun; I have no intention of aggressively pursuing this direction.

It is just a fancy byproduct of the tensorcircuit project, whose focus is quantum computation rather than ML system design.

In [4]:
def universal_code(a, b):
    """
    Differentiate a batched trace objective, written once against ``tc.backend``.

    The objective is ``sum_i Re(tr(a[i] + b[i]))``: the per-sample trace is
    jitted, vectorized over the batch with the backend's ``vmap``, and summed
    to a scalar. Because every operation goes through ``tc.backend``, the same
    body runs on whichever backend and dtype are currently configured via
    ``tc.set_backend`` / ``tc.set_dtype``.

    :param a: array-like batch of matrices (``(2, 2, 2)`` in this notebook);
        converted to a backend tensor with ``tc.gates.num_to_tensor``.
    :param b: array-like batch with the same shape as ``a``.
    :return: the gradients of the objective with respect to ``a`` and ``b``.
        The container is backend dependent (a tuple on jax, a list on
        pytorch/tensorflow in the outputs below).
    """
    # Resolve the backend at call time, so that a preceding tc.set_backend(...)
    # is honoured on every invocation.
    K = tc.backend

    def _single_trace(x, y):
        # Real part of the trace of one summed matrix pair.
        return K.real(K.trace(x + y))

    # jit the per-sample kernel, then map it across the leading batch axis.
    per_sample = K.vmap(K.jit(_single_trace))

    def objective(x, y):
        # Collapse the per-sample values into one scalar so grad applies.
        traces = per_sample(x, y)
        return K.einsum("i->", traces)

    gradient_fn = K.grad(objective, argnums=[0, 1])
    x0, y0 = tc.gates.num_to_tensor(a, b)
    return gradient_fn(x0, y0)

In [5]:
tc.set_backend("jax")
# Pin the dtype explicitly. The later cells change the global dtype with
# tc.set_dtype, so re-running this cell after them would otherwise silently
# reuse their setting instead of the complex64 shown in the output below.
tc.set_dtype("complex64")
universal_code(np.ones([2, 2, 2]), np.ones([2, 2, 2]))

(DeviceArray([[[1.+0.j, 0.+0.j],
               [0.+0.j, 1.+0.j]],
 
              [[1.+0.j, 0.+0.j],
               [0.+0.j, 1.+0.j]]], dtype=complex64),
 DeviceArray([[[1.+0.j, 0.+0.j],
               [0.+0.j, 1.+0.j]],
 
              [[1.+0.j, 0.+0.j],
               [0.+0.j, 1.+0.j]]], dtype=complex64))

In [6]:
# Same universal_code on the pytorch backend with real float64 tensors.
# tensorcircuit emits a warning here that pytorch has no intrinsic vmap
# (see the output below); the gradients still come back as float64 tensors.
tc.set_backend("pytorch")
tc.set_dtype("float64")
universal_code(np.ones([2, 2, 2]), np.ones([2, 2, 2]))

  "pytorch backend has no intrinsic vmap like interface"


[tensor([[[1., 0.],
          [0., 1.]],
 
         [[1., 0.],
          [0., 1.]]], dtype=torch.float64), tensor([[[1., 0.],
          [0., 1.]],
 
         [[1., 0.],
          [0., 1.]]], dtype=torch.float64)]

In [9]:
# Same universal_code on the tensorflow backend with real float64 tensors.
# NOTE(review): this cell's execution count ([9]) is later than the next
# cell's ([8]); the outputs were produced out of order. Both cells set their
# own dtype, so a fresh Restart & Run All should still reproduce them.
tc.set_backend("tensorflow")
tc.set_dtype("float64")
universal_code(np.ones([2, 2, 2]), np.ones([2, 2, 2]))

[<tf.Tensor: shape=(2, 2, 2), dtype=float64, numpy=
 array([[[1., 0.],
         [0., 1.]],
 
        [[1., 0.],
         [0., 1.]]])>, <tf.Tensor: shape=(2, 2, 2), dtype=float64, numpy=
 array([[[1., 0.],
         [0., 1.]],
 
        [[1., 0.],
         [0., 1.]]])>]

In [8]:
# Same universal_code on the tensorflow backend, now with complex128 tensors:
# only the dtype changes, the function body is untouched.
tc.set_backend("tensorflow")
tc.set_dtype("complex128")
universal_code(np.ones([2, 2, 2]), np.ones([2, 2, 2]))

[<tf.Tensor: shape=(2, 2, 2), dtype=complex128, numpy=
 array([[[1.+0.j, 0.+0.j],
         [0.+0.j, 1.+0.j]],
 
        [[1.+0.j, 0.+0.j],
         [0.+0.j, 1.+0.j]]])>,
 <tf.Tensor: shape=(2, 2, 2), dtype=complex128, numpy=
 array([[[1.+0.j, 0.+0.j],
         [0.+0.j, 1.+0.j]],
 
        [[1.+0.j, 0.+0.j],
         [0.+0.j, 1.+0.j]]])>]