# Distributed Circuit Simulation and TensorNetwork Contraction

## Overview

Simulating large quantum circuits or computing expectation values for complex Hamiltonians often involves contracting a massive tensor network. The computational cost (both time and memory) of this contraction can be a significant bottleneck, especially for systems with many qubits.

TensorCircuit provides an experimental feature, `DistributedContractor`, designed to tackle this challenge. It leverages multiple devices (e.g., GPUs) to parallelize the tensor network contraction. The core idea is:

1.  **Pathfinding with `cotengra`**: It first uses the `cotengra` library to search for an efficient contraction path for the given tensor network. This path often involves "slicing" the network, which breaks the single large contraction into many smaller, independent contractions.
2.  **Distributed Computation**: It then distributes these smaller contraction tasks across all available devices. Each device computes a subset of the slices in parallel.
3.  **Aggregation**: Finally, the results from all devices are aggregated to produce the final value (e.g., an expectation value or a state amplitude).
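
The slicing idea in step 1 can be checked on a toy example. The NumPy sketch below is an illustration only. The tensors and shapes are arbitrary, and this is not how `cotengra` or `DistributedContractor` are implemented. It contracts a three-tensor ring in one go, then repeats the contraction with the shared index `j` fixed to each of its values in turn:

```python
import numpy as np

rng = np.random.default_rng(0)
# Toy network: A[i,j] B[j,k] C[k,i] contracted to a scalar (a trace of a product).
A = rng.normal(size=(4, 8))
B = rng.normal(size=(8, 6))
C = rng.normal(size=(6, 4))

# One full contraction over all indices.
full = np.einsum("ij,jk,ki->", A, B, C)

# Slice the shared index j: each fixed value of j is an independent, smaller contraction.
slices = [np.einsum("i,k,ki->", A[:, j], B[j, :], C) for j in range(A.shape[1])]

# Summing the slice results recovers the full contraction (up to rounding).
sliced_total = sum(slices)
print(np.isclose(full, sliced_total))
```

Each slice works with tensors that no longer carry the `j` index, which is where the memory saving comes from. The price is computational: any work that does not involve `j` is repeated once per slice.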

Slicing bounds the memory that each contraction needs, and spreading the slices across devices reduces wall-clock time. Together, they let this approach tackle problems that would be too large or too slow for a single device.

In this tutorial, we will demonstrate how to use `DistributedContractor` for two common tasks:
-   Calculating the amplitude of a specific bitstring for a large quantum state.
-   Running a Variational Quantum Eigensolver (VQE) optimization for a transverse-field Ising model.

## Setup

First, let's configure JAX to use multiple (virtual) devices and import the necessary libraries.


```python
import os

# Request multiple virtual CPU devices; this must be set before JAX is imported
NUM_DEVICES = 4
os.environ["XLA_FLAGS"] = f"--xla_force_host_platform_device_count={NUM_DEVICES}"

import time
import jax
from jax import numpy as jnp
import numpy as np
import optax
import tensorcircuit as tc
from tensorcircuit.experimental import DistributedContractor

K = tc.set_backend("jax")
tc.set_dtype("complex64")

# Verify that JAX sees the configured number of devices
print(f"JAX is using {jax.local_device_count()} devices.")
```

```text
JAX is using 4 devices.
```

## `DistributedContractor`: Mechanism

Before diving into examples, let's understand the core components and the inner workings of the `DistributedContractor`.

### The `nodes_fn` Template

The central requirement for using `DistributedContractor` is to provide a function, which we conventionally call `nodes_fn`. This function serves as a template that defines the structure of the tensor network.

-   **Input**: `nodes_fn` must accept a single argument (`params`), typically a dictionary or another JAX PyTree of parameters. These are the quantities you vary between calls (e.g., the variational parameters of a circuit).
-   **Output**: It must return the list of tensor nodes (`tc.Gate` or `tn.Node` objects, which contain tensors) that constitute the tensor network *before the final contraction*. `tensorcircuit-ng` provides convenient methods like `.expectation_before()` and `.amplitude_before()` for this purpose.

The `DistributedContractor` calls this `nodes_fn` once during its initialization (`__init__`) with a set of template parameters. It does this to understand the network's connectivity and size, which is necessary for the `cotengra` pathfinder. The *actual values* in the tensors from this initial call are discarded; only the *structure* is used.

### The Internal Mechanism

Here's a step-by-step breakdown of what happens inside `DistributedContractor`:

1.  **Initialization (`__init__`)**:
    - You provide the `nodes_fn` and a set of `params`.
    - `DistributedContractor` calls `nodes_fn(params)` to get the tensor network structure.
    - It passes this structure to `cotengra`'s `ReusableHyperOptimizer`.
    - **Pathfinding**: `cotengra` then runs a randomized, hyper-optimized search over candidate contraction paths. The number of trials is set by `max_repeats` in `cotengra_options`. A key part of this is **slicing**. If the largest intermediate tensor in the best path exceeds a size limit (which you can control via `cotengra_options`, e.g. `target_size` in `slicing_reconf_opts`), `cotengra` will "slice" one or more indices (edges) of the network. Slicing means the contraction is repeated for each possible value of the sliced indices, and the results are summed up. This trades extra, partly redundant computation for a drastic reduction in memory.
    - The final output of this step is a `ContractionTree`, a plan that details the sequence of pairwise contractions and the slicing strategy.
    - **Task Distribution**: The contractor then divides the list of slices evenly among the available JAX devices.

2.  **Execution (`.value()` or `.value_and_grad()`)**:
    - You call the method with a *new* set of `params`.
    - The contractor uses JAX's `pmap` to send the contraction plan and the new `params` to all devices.
    - **Parallel Execution**: Each device, in parallel:
        - Calls your `nodes_fn` with the new `params` to generate the tensors with their *current numerical values*.
        - Iterates through its assigned subset of slices.
        - For each slice, it performs the small, memory-efficient contraction as prescribed by the `ContractionTree`.
        - It sums up the results of all its assigned slices.
    - **Aggregation**: The results from each device are postprocessed by the optional `op` function, which can be passed to the `value()` and `value_and_grad()` methods. They are then summed on the host to produce the total value. If `.value_and_grad()` was called, the gradients are aggregated in the same way.
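
The execution and aggregation steps above can be emulated on a single machine. The JAX sketch below is a toy under stated assumptions, not a description of `DistributedContractor` internals:

- `sliced_value` is a hypothetical stand-in for one slice's contraction.
- `jax.vmap` over a leading device axis stands in for `pmap`.
- Padding the slice list to a multiple of the device count (and masking the padding) is an assumed way of handling uneven splits.

The sketch shows why summing per-device values and per-device gradients matches differentiating the full sum:

```python
import math

import jax
import jax.numpy as jnp

NUM_EMULATED_DEVICES = 4  # illustration only; no real devices are used


def sliced_value(theta, j):
    """Hypothetical stand-in for contracting slice j: one term of a sum over a sliced index."""
    return jnp.sin(theta * (j + 1)) ** 2


def full_value(theta, num_slices):
    """Reference: the unsliced result is the sum over every slice."""
    return sum(sliced_value(theta, j) for j in range(num_slices))


def distributed_value_and_grad(theta, num_slices, num_devices=NUM_EMULATED_DEVICES):
    per_device = math.ceil(num_slices / num_devices)
    # Assign slice ids to devices; pad to a full grid and mask the padding out.
    ids = jnp.arange(per_device * num_devices).reshape(num_devices, per_device)
    mask = ids < num_slices

    def device_fn(theta, dev_ids, dev_mask):
        # Each emulated device sums its own slices and differentiates that partial sum.
        def partial_sum(t):
            vals = jax.vmap(lambda j: sliced_value(t, j))(dev_ids)
            return jnp.sum(jnp.where(dev_mask, vals, 0.0))

        return jax.value_and_grad(partial_sum)(theta)

    # jax.vmap over the device axis stands in for jax.pmap across real devices.
    vals, grads = jax.vmap(device_fn, in_axes=(None, 0, 0))(theta, ids, mask)
    # Host-side aggregation: values and gradients are both summed over devices.
    return jnp.sum(vals), jnp.sum(grads)


theta = 0.3
v, g = distributed_value_and_grad(theta, num_slices=10)
v_ref, g_ref = jax.value_and_grad(full_value)(theta, 10)
print(v, v_ref, g, g_ref)
```

Because differentiation is linear, the gradient of a sum of slices is the sum of the per-slice gradients. This is why values and gradients can share the same aggregation step.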

This design is powerful because the costly pathfinding step is done only once. All subsequent calls with different parameters reuse the same optimized path, which makes each call fast. This matters most in iterative algorithms like VQE.
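
NumPy's `np.einsum_path` gives a small-scale analogue of this find-once, reuse-many pattern. It is not `cotengra`, and the ring of four tensors below is arbitrary. Because a path depends only on tensor shapes, it can be computed from placeholder tensors. This mirrors how `DistributedContractor` uses only the structure of its template call. The path is then reused for new values:

```python
import numpy as np

rng = np.random.default_rng(1)
subscripts = "ab,bc,cd,da->"  # a small closed ring of four tensors
shapes = [(8, 16), (16, 4), (4, 32), (32, 8)]

# Pathfinding only looks at shapes, so template tensors with arbitrary values suffice.
template = [np.zeros(shape) for shape in shapes]
path, _report = np.einsum_path(subscripts, *template, optimize="optimal")

# Reuse the precomputed path for several new sets of tensor values.
checks = []
for step in range(3):
    tensors = [rng.normal(size=shape) for shape in shapes]
    fast = np.einsum(subscripts, *tensors, optimize=path)
    reference = np.einsum(subscripts, *tensors, optimize=False)
    checks.append(bool(np.isclose(fast, reference)))
print(path, checks)
```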

## Example 1: Calculating State Amplitudes

One fundamental task is to compute the amplitude of a specific basis state, $\langle s | \psi \rangle$, where $|s\rangle$ is a bitstring like $|0110\dots\rangle$ and $|\psi\rangle$ is the state produced by a quantum circuit.
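
Conceptually, $\langle s | \psi \rangle$ is the state tensor with each qubit's output leg capped by the basis vector $\langle s_q|$. The NumPy sketch below checks this on a random 3-qubit state. It assumes qubit 0 is the most significant bit, which is the ordering produced by the `get_binary_representation` helper defined below. It does not show how `amplitude_before` builds its nodes internally.

```python
import numpy as np

n = 3
rng = np.random.default_rng(2)
psi = rng.normal(size=2**n) + 1j * rng.normal(size=2**n)
psi /= np.linalg.norm(psi)

bits = [0, 1, 1]  # the bitstring s = |011>, qubit 0 first (most significant)
index = int("".join(map(str, bits)), 2)

# Network view: one leg per qubit, each capped with the one-hot vector <s_q|.
psi_tensor = psi.reshape([2] * n)
caps = [np.eye(2)[b] for b in bits]
amp_network = np.einsum("abc,a,b,c->", psi_tensor, *caps)

# Same number as reading the dense state vector at the bitstring's integer index.
print(np.isclose(amp_network, psi[index]))
```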

### Defining the `nodes_fn`


```python
N_QUBITS_AMP = 14
DEPTH_AMP = 7


def circuit_ansatz(n, d, params):
    """A standard hardware-efficient ansatz; params has shape [n, d, 3]."""
    c = tc.Circuit(n)
    c.h(range(n))
    for i in range(d):
        for j in range(0, n - 1):
            c.rzz(j, j + 1, theta=params[j, i, 0])
        for j in range(n):
            c.rx(j, theta=params[j, i, 1])
        for j in range(n):
            c.ry(j, theta=params[j, i, 2])
    return c


def get_nodes_fn_amp(n, d):
    """
    Return the function that is passed to DistributedContractor.
    The inner function takes a dictionary of parameters and returns the
    tensor network nodes for a single amplitude, before contraction.
    """

    def nodes_fn(params):
        psi = circuit_ansatz(n, d, params["circuit"])
        # `amplitude_before` gives us the tensor network before final contraction
        return psi.amplitude_before(params["amplitude"])

    return nodes_fn


def get_binary_representation(i: int, N: int) -> jax.Array:
    """Convert an integer to its N-bit binary representation, most significant bit first."""
    shifts = jnp.arange(N - 1, -1, -1)
    return ((i >> shifts) & 1).astype(jnp.int32)
```

### Initializing the `DistributedContractor`

```python
nodes_fn_amp = get_nodes_fn_amp(N_QUBITS_AMP, DEPTH_AMP)

# We need some initial parameters to define the network structure
key = jax.random.PRNGKey(42)
params_circuit_amp = (
    jax.random.normal(key, shape=[N_QUBITS_AMP, DEPTH_AMP, 3], dtype=tc.rdtypestr) * 0.1
)
initial_params_amp = {
    "circuit": params_circuit_amp,
    "amplitude": get_binary_representation(0, N_QUBITS_AMP),
}

print("Initializing DistributedContractor for amplitude calculation...")
# cotengra_options allow fine-tuning of the pathfinding process.
# `target_size` in `slicing_reconf_opts` caps the size (number of elements)
# of the largest intermediate tensor within each slice.
DC_amp = DistributedContractor(
    nodes_fn=nodes_fn_amp,
    params=initial_params_amp,
    cotengra_options={
        "slicing_reconf_opts": {"target_size": 2**14},
        "max_repeats": 64,  # number of path-search trials
        "progbar": True,
        "minimize": "write",  # minimize the total memory written by intermediates
        "parallel": 4,  # parallel workers for the path search
    },
)
```

```text
Initializing DistributedContractor for amplitude calculation...
F=4.88 C=6.02 S=9.00 P=11.21: 100%|██████████| 64/64 [00:08<00:00,  7.70it/s] 

--- Contraction Path Info ---
Path found with 1 slices.
Arithmetic Intensity (higher is better): 4.94
flops (TFlops): 1.7143975128419697e-08
write (GB): 0.00011376291513442993
size (GB): 3.814697265625e-06
-----------------------------

Distributing across 4 devices. Each device will sequentially process up to 1 slices.
```
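
The path statistics printed above can be cross-checked with a little arithmetic. The sketch below rests on two assumptions. First, `S` in the `cotengra` progress bar is taken to be log2 of the element count of the largest intermediate tensor. Second, the `size (GB)` line is taken to count 8 bytes per `complex64` element and `2**30` bytes per GB. Both are assumptions, but they reproduce the printed numbers exactly. The sketch converts between the two quantities:

```python
import math

BYTES_PER_ELEMENT = 8  # complex64: two float32 values
BYTES_PER_GB = 2**30   # assumption: the log's "GB" means 2**30 bytes


def log2_size_to_gb(log2_size):
    """Memory of the largest intermediate tensor, given log2(number of elements)."""
    return 2**log2_size * BYTES_PER_ELEMENT / BYTES_PER_GB


def gb_to_log2_size(gb):
    """Inverse: recover log2(number of elements) from a reported size in GB."""
    return math.log2(gb * BYTES_PER_GB / BYTES_PER_ELEMENT)


# Amplitude network: S=9.00 in the progress bar.
print(log2_size_to_gb(9))  # 3.814697265625e-06
# VQE network (below): size (GB) = 1.430511474609375e-06 is 192 elements, S about 7.58.
print(gb_to_log2_size(1.430511474609375e-06))
```

Under these assumptions, the amplitude network's largest intermediate (512 elements) is already below `target_size = 2**14`. That is consistent with the single slice reported. The VQE network's 192-element maximum sits under its `2**8` target, with 16 slices reported.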





### Calculating Multiple Amplitudes


```python
n_amp = 10
print("Starting amplitude loop...")
for i in range(n_amp):
    bs_vector = get_binary_representation(i, N_QUBITS_AMP)
    # Same circuit parameters, new bitstring: the cached contraction path is reused
    params = {"circuit": params_circuit_amp, "amplitude": bs_vector}

    t0 = time.time()
    amp = DC_amp.value(params)
    t1 = time.time()

    print(
        f"Bitstring: {bs_vector.tolist()} | "
        f"Amp (DC): {amp:.8f} | "
        f"Time: {t1 - t0:.4f} s"
    )
```

```text
Starting amplitude loop...
Bitstring: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] | Amp (DC): 0.00353913+0.00129915j | Time: 0.0011 s
Bitstring: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1] | Amp (DC): 0.00360568+0.00364827j | Time: 0.0007 s
Bitstring: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0] | Amp (DC): 0.00032678+0.00287017j | Time: 0.0007 s
Bitstring: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1] | Amp (DC): 0.00228170+0.00377162j | Time: 0.0006 s
Bitstring: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0] | Amp (DC): 0.00194864+0.00337264j | Time: 0.0006 s
Bitstring: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1] | Amp (DC): 0.00013618+0.00505736j | Time: 0.0005 s
Bitstring: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0] | Amp (DC): 0.00230806+0.00277415j | Time: 0.0006 s
Bitstring: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1] | Amp (DC): 0.00485482+0.00290096j | Time: 0.0005 s
Bitstring: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0] | Amp (DC): 0.00410894+0.00097546j | Time: 0.0005 s
Bitstring: [0, 0, 0
```

## Example 2: Variational Quantum Eigensolver (VQE)

`DistributedContractor` is especially powerful for variational algorithms, where we need to repeatedly compute expectation values and their gradients.

### Defining the `nodes_fn` for Expectation Value

```python
N_QUBITS_VQE = 10
DEPTH_VQE = 4


def get_tfi_mpo(n):
    """Gets the MPO for the 1D Transverse-Field Ising model Hamiltonian."""
    import tensornetwork as tn

    Jx = np.ones(n - 1)
    Bz = -1.0 * np.ones(n)
    tn_mpo = tn.matrixproductstates.mpo.FiniteTFI(Jx, Bz, dtype=np.complex64)
    return tc.quantum.tn2qop(tn_mpo)


def get_nodes_fn_vqe(n, d, mpo):
    """
    The nodes_fn for VQE expectation value.
    It returns the list of tensors for <psi|H|psi>.
    """

    def nodes_fn(params):
        psi = circuit_ansatz(n, d, params).get_quvector()
        expression = psi.adjoint() @ mpo @ psi
        return expression.nodes

    return nodes_fn
```
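
To make the `psi.adjoint() @ mpo @ psi` sandwich concrete, here is a NumPy sketch of the network it describes. It uses a textbook bond-dimension-3 MPO for $H = -J\sum_i X_i X_{i+1} - h\sum_i Z_i$. This is an illustration under stated assumptions, not the `FiniteTFI` or `tn2qop` implementation, and its sign convention may differ from `FiniteTFI`'s.

```python
from functools import reduce

import numpy as np

X = np.array([[0.0, 1.0], [1.0, 0.0]])
Z = np.diag([1.0, -1.0])
I2 = np.eye(2)


def tfi_mpo(n, J=1.0, h=1.0):
    """Bond-dimension-3 MPO for H = -J sum X_i X_{i+1} - h sum Z_i, n >= 2 sites."""
    W = np.zeros((3, 3, 2, 2))  # axes: (left bond, right bond, phys out, phys in)
    W[0, 0] = I2
    W[1, 0] = X
    W[2, 0] = -h * Z
    W[2, 1] = -J * X
    W[2, 2] = I2
    # Boundary tensors keep only the last row (left end) or the first column (right end).
    return [W[2:3]] + [W] * (n - 2) + [W[:, 0:1]]


def mpo_expectation(psi, mpo):
    """<psi|H|psi> as a single network: bra, every MPO tensor and ket, contracted at once."""
    n = len(mpo)
    letters = iter("abcdefghijklmnopqrstuvwxyz")  # enough labels for n <= 8
    bonds = [next(letters) for _ in range(n + 1)]
    outs = [next(letters) for _ in range(n)]
    ins = [next(letters) for _ in range(n)]
    terms = ["".join(outs)]
    terms += [bonds[i] + bonds[i + 1] + outs[i] + ins[i] for i in range(n)]
    terms += ["".join(ins)]
    return np.einsum(",".join(terms) + "->", psi.conj(), *mpo, psi, optimize=True)


n = 3
rng = np.random.default_rng(3)
psi = rng.normal(size=[2] * n) + 1j * rng.normal(size=[2] * n)
psi /= np.linalg.norm(psi)

# Explicit 3-site Hamiltonian for comparison (site 0 is the leftmost Kronecker factor).
H3 = -(reduce(np.kron, [X, X, I2]) + reduce(np.kron, [I2, X, X]))
H3 -= reduce(np.kron, [Z, I2, I2]) + reduce(np.kron, [I2, Z, I2]) + reduce(np.kron, [I2, I2, Z])

flat = psi.reshape(-1)
e_network = mpo_expectation(psi, tfi_mpo(n))
e_dense = np.vdot(flat, H3 @ flat)
print(np.isclose(e_network, e_dense))
```

Contracting all bra, MPO and ket tensors in one `np.einsum` call mirrors the kind of closed network that `DistributedContractor` hands to `cotengra` for pathfinding and slicing.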

### VQE Optimization Loop


```python
tfi_mpo = get_tfi_mpo(N_QUBITS_VQE)
nodes_fn_vqe = get_nodes_fn_vqe(N_QUBITS_VQE, DEPTH_VQE, tfi_mpo)

# Initial parameters for VQE
key = jax.random.PRNGKey(42)
params_vqe = (
    jax.random.normal(key, shape=[N_QUBITS_VQE, DEPTH_VQE, 3], dtype=tc.rdtypestr) * 0.1
)

print("\nInitializing DistributedContractor for VQE...")
DC_vqe = DistributedContractor(
    nodes_fn=nodes_fn_vqe,
    params=params_vqe,
    cotengra_options={
        # Smaller target size for the VQE network
        "slicing_reconf_opts": {"target_size": 2**8},
        "max_repeats": 16,
        "progbar": True,
        "minimize": "write",
        "parallel": 4,
    },
)

# Setup Optax optimizer
optimizer = optax.adam(2e-2)
opt_state = optimizer.init(params_vqe)


@K.jit
def opt_update(params, opt_state, grads):
    updates, new_opt_state = optimizer.update(grads, opt_state, params)
    new_params = optax.apply_updates(params, updates)
    return new_params, new_opt_state


n_steps_vqe = 100
print("\nStarting VQE optimization loop...")
for i in range(n_steps_vqe):
    t0 = time.time()
    loss, grads = DC_vqe.value_and_grad(params_vqe)

    params_vqe, opt_state = opt_update(params_vqe, opt_state, grads)
    t1 = time.time()

    if (i + 1) % 10 == 0:
        print(f"Step {i+1:03d} | Loss: {loss:.8f} | Time: {t1 - t0:.4f} s")

print("\nOptimization finished.")
```

```text
Initializing DistributedContractor for VQE...
F=5.61 C=6.81 S=7.58 P=11.10 $=16.00: 100%|██████████| 16/16 [00:08<00:00,  1.98it/s]

--- Contraction Path Info ---
Path found with 16 slices.
Arithmetic Intensity (higher is better): 4.30
flops (TFlops): 9.178620530292392e-08
write (GB): 4.3742358684539795e-05
size (GB): 1.430511474609375e-06
-----------------------------

Distributing across 4 devices. Each device will sequentially process up to 4 slices.

Starting VQE optimization loop...
Step 010 | Loss: -3.29106593 | Time: 0.0022 s
Step 020 | Loss: -8.78426552 | Time: 0.0021 s
Step 030 | Loss: -10.85906601 | Time: 0.0022 s
Step 040 | Loss: -11.47075844 | Time: 0.0020 s
Step 050 | Loss: -11.72393227 | Time: 0.0022 s
Step 060 | Loss: -11.91652107 | Time: 0.0021 s
Step 070 | Loss: -12.04074574 | Time: 0.0020 s
Step 080 | Loss: -12.11665630 | Time: 0.0024 s
Step 090 | Loss: -12.15851784 | Time: 0.0021 s
Step 100 | Loss: -12.17830086 | Time: 0.0022 s

Optimization finished.
```
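
Ten sites are still small enough for exact diagonalization of a dense $2^{10} \times 2^{10}$ matrix, which shows how close the final VQE loss comes to the true ground-state energy. The NumPy sketch below assumes $H = -J\sum_i X_i X_{i+1} - h\sum_i Z_i$ with Pauli operators, $J = h = 1$, and an open chain. `FiniteTFI` may use a different sign convention. However, flipping the sign of either term is a unitary change of basis (conjugating by $Z$ on every other site, or by $X$ on every site), so the ground-state energy is unchanged. If these assumptions hold, the variational principle says the VQE loss cannot go below the printed value.

```python
from functools import reduce

import numpy as np

X = np.array([[0.0, 1.0], [1.0, 0.0]])
Z = np.diag([1.0, -1.0])
I2 = np.eye(2)


def pauli_string(ops, n):
    """Kronecker product with ops[k] on site k (site 0 most significant), identity elsewhere."""
    return reduce(np.kron, [ops.get(k, I2) for k in range(n)])


def dense_tfi(n, J=1.0, h=1.0):
    """Dense open-chain H = -J sum X_i X_{i+1} - h sum Z_i."""
    H = np.zeros((2**n, 2**n))
    for i in range(n - 1):
        H -= J * pauli_string({i: X, i + 1: X}, n)
    for i in range(n):
        H -= h * pauli_string({i: Z}, n)
    return H


# eigvalsh returns eigenvalues in ascending order, so index 0 is the ground state.
e0 = np.linalg.eigvalsh(dense_tfi(10))[0]
print(f"Exact ground-state energy for 10 sites: {e0:.8f}")
```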


## Conclusion

The `DistributedContractor` provides a powerful and streamlined interface for scaling up tensor network contractions to multiple devices. By abstracting away the complexities of pathfinding, slicing, and parallel execution, it allows researchers to focus on the physics of their problem while leveraging the full computational power of their hardware. This is particularly advantageous for simulating large quantum circuits and for speeding up each iteration of variational quantum algorithms.