In [4]:
import logging
import math
import os
from functools import partial, reduce
from operator import matmul, mul, or_
from typing import (
    Any,
    Callable,
    Collection,
    Dict,
    Iterable,
    List,
    Optional,
    Sequence,
    Tuple,
    Union,
)

import numpy as np
from tensornetwork.network_components import AbstractNode, CopyNode, Edge, Node, connect
from tensornetwork.network_operations import (
    copy,
    get_subgraph_dangling,
    remove_node,
)

from tenpy.networks import MPO, MPS, Site
from tenpy.linalg import np_conserved as npc
from tenpy.linalg import LegCharge
import tensornetwork as tn
from quimb.tensor import MatrixProductOperator
import quimb.tensor as qtn

from tensorcircuit.cons import backend, contractor, dtypestr, npdtype, rdtypestr
from tensorcircuit.gates import Gate, num_to_tensor
from tensorcircuit.utils import arg_alias

# Loose type aliases: both names are simply bound to ``typing.Any``.
Tensor = Graph = Any

from tenpy.models.tf_ising import TFIChain
from tensorcircuit.quantum import QuOperator, quantum_constructor
import tensorcircuit as tc
import pytest
from pytest_lazyfixture import lazy_fixture as lf

import tensornetwork as tn
from functools import reduce
from operator import mul

import tensornetwork as tn
from functools import reduce
from operator import mul
import numpy as np

"""
Demonstration of TeNPy-DMRG and TensorCircuit integration
1. Compute ground state (MPS) of 1D Heisenberg model using TeNPy
2. Convert MPS to TensorCircuit's QuOperator
3. Initialize MPSCircuit with converted state and verify results
"""

import numpy as np
from typing import Union, Any
from tenpy.networks.mps import MPS
from tenpy.networks.mpo import MPO
from tenpy.models.xxz_chain import XXZChain
from tenpy.algorithms import dmrg
import tensornetwork as tn

import tensorcircuit as tc

# Module-level aliases for tensornetwork / tensorcircuit names.  They rebind
# names that were already imported earlier in this file.
Node, Edge, connect = tn.Node, tn.Edge, tn.connect

QuOperator, quantum_constructor = (
    tc.quantum.QuOperator,
    tc.quantum.quantum_constructor,
)


def tenpy2qop(tenpy_obj: Union["MPS", "MPO"]) -> QuOperator:
    """
    Convert a TeNPy MPO or MPS to a TensorCircuit QuOperator.

    Every TeNPy tensor is first transposed *by leg label* into a fixed axis
    order, so the result does not depend on the order in which the legs are
    stored:

    * MPO tensors become ``(wL, p, p*, wR)``.  The left boundary tensor is
      restricted to index 0 of ``wL`` and the right boundary tensor to the
      last index of ``wR``, so both outer virtual legs have dimension 1.
      The ``p*`` legs become ``out_edges`` and the ``p`` legs ``in_edges``.
    * MPS tensors become ``(vL, p, vR)``.  The ``p`` legs become
      ``out_edges`` and there are no ``in_edges``.

    In both cases the two outermost virtual legs are passed as
    ``ignore_edges``.  An object with no tensors yields an empty QuOperator.

    :param tenpy_obj: A MPO or MPS object from the TeNPy package.
    :type tenpy_obj: Union[tenpy.networks.mpo.MPO, tenpy.networks.mps.MPS]
    :return: The corresponding state or operator as a QuOperator.
    :rtype: QuOperator
    """
    # MPOs keep their tensors in ``_W``; MPSs keep them in ``_B``.
    is_mpo = hasattr(tenpy_obj, "_W")
    tenpy_tensors = tenpy_obj._W if is_mpo else tenpy_obj._B
    nwires = len(tenpy_tensors)
    if nwires == 0:
        return quantum_constructor([], [], [])

    if is_mpo:
        axis_names = ["wL", "p", "p*", "wR"]
        nodes = []
        for i, W in enumerate(tenpy_tensors):
            # Reorder by label (``transpose`` works on a copy), so the fixed
            # positional slicing below always hits the intended legs.
            tensor = W.transpose(axis_names).to_ndarray()
            if i == 0:
                # Left boundary: keep wL index 0 as a length-1 axis.
                tensor = tensor[0:1, ...]
            if i == nwires - 1:
                # Right boundary: keep the last wR index as a length-1 axis.
                tensor = tensor[..., -1:]
            nodes.append(
                Node(tensor, name=f"tensor_{i}", axis_names=list(axis_names))
            )

        for left, right in zip(nodes[:-1], nodes[1:]):
            connect(left["wR"], right["wL"])

        # NOTE(review): "p*" legs are used as out_edges and "p" legs as
        # in_edges.  If TeNPy's "p" is the ket (output) leg, eval_matrix()
        # returns the transpose of the operator (irrelevant for symmetric
        # operators) -- confirm the intended convention before changing it.
        out_edges = [node["p*"] for node in nodes]
        in_edges = [node["p"] for node in nodes]
        ignore_edges = [nodes[0]["wL"], nodes[-1]["wR"]]
    else:
        nodes = [Node(B.transpose(["vL", "p", "vR"]).to_ndarray()) for B in tenpy_tensors]
        for left, right in zip(nodes[:-1], nodes[1:]):
            left[2] ^ right[0]
        out_edges = [node[1] for node in nodes]
        in_edges = []
        ignore_edges = [nodes[0][0], nodes[-1][2]]

    return quantum_constructor(out_edges, in_edges, [], ignore_edges)


def qop2tenpy(qop: QuOperator) -> Any:
    """
    Convert a TensorCircuit QuOperator to a TeNPy MPO or MPS.

    A QuOperator without ``in_edges`` is converted to an MPS; otherwise it is
    converted to an MPO.  Node tensors are expected in the layout produced by
    ``tenpy2qop``: ``(vL, p, vR)`` for MPS nodes and ``(wL, p, p*, wR)`` for
    MPO nodes.

    The nodes are put into chain order by walking the virtual bonds from an
    endpoint node.  When ``qop.ignore_edges`` is non-empty, the chain is
    oriented to start at the node that owns ``qop.ignore_edges[0]``.

    For an MPO, ``chi`` is the ``wR`` size of the first node.  Boundary
    virtual legs smaller than ``chi`` are zero-padded to ``chi``.  The left
    boundary data is placed at ``IdL = 0`` and the right boundary data at
    ``IdR = chi - 1`` (0 when ``chi == 1``).

    :param qop: The state/operator as a QuOperator.
    :return: MPO or MPS object from the TeNPy package.
    :rtype: Union[tenpy.networks.mpo.MPO, tenpy.networks.mps.MPS]
    :raises ValueError: If ``qop`` has no physical edges to turn into sites.
    """
    is_mps = len(qop.in_edges) == 0
    # Edges that define the sites: an MPS only has output edges, while an
    # MPO is indexed by its input edges.  (Reading ``in_edges[0]`` for an
    # MPS would always fail, since an MPS has no input edges.)
    site_edges = qop.out_edges if is_mps else qop.in_edges
    nwires = len(site_edges)
    if nwires == 0:
        raise ValueError(
            "qop2tenpy requires a QuOperator with at least one physical edge"
        )

    physical_edges = set(qop.in_edges + qop.out_edges)

    # --- Locate the two ends of the chain -------------------------------
    endpoint_nodes = {e.node1 for e in qop.ignore_edges if e.node1 is not None}
    if len(endpoint_nodes) < 2 and len(qop.ref_nodes) > 1:
        # Fallback: an end node has exactly one non-physical edge.
        inferred_endpoints = {
            node
            for node in qop.ref_nodes
            if sum(1 for e in node.edges if e not in physical_edges) == 1
        }
        if len(inferred_endpoints) == 2:
            endpoint_nodes = inferred_endpoints

    # --- Walk the virtual bonds to order the nodes ----------------------
    sorted_nodes: List[Node] = []

    def _next_unvisited(node: Any) -> Any:
        """Return the first bond neighbour of ``node`` not yet ordered, else None."""
        for e in node.edges:
            if e.is_dangling() or e in physical_edges:
                continue
            neighbour = e.node2 if e.node1 is node else e.node1
            if neighbour not in sorted_nodes:
                return neighbour
        return None

    if endpoint_nodes:
        current = next(iter(endpoint_nodes))
        while current is not None and len(sorted_nodes) < nwires:
            sorted_nodes.append(current)
            current = _next_unvisited(current)

    if not sorted_nodes:
        # ``qop.nodes`` may be an unordered collection.  Materialise it so
        # the indexing and reversal below are valid.
        sorted_nodes = list(qop.nodes)

    if sorted_nodes and qop.ignore_edges:
        if sorted_nodes[0] is not qop.ignore_edges[0].node1:
            sorted_nodes = sorted_nodes[::-1]

    physical_dim = site_edges[0].dimension
    sites = [Site(LegCharge.from_trivial(physical_dim), "q") for _ in range(nwires)]

    # --- MPS conversion -------------------------------------------------
    if is_mps:
        tensors = []
        last = len(sorted_nodes) - 1
        for i, node in enumerate(sorted_nodes):
            tensor = np.asarray(node.tensor)
            # NOTE(review): heuristic kept from the original implementation.
            # A (2, 2, 2) boundary tensor is cut down to its first outer
            # virtual index, and only that exact shape triggers it.  Confirm
            # which producer needs this before generalising or removing it.
            if i == 0 and tensor.shape == (2, 2, 2):
                tensor = tensor[0:1, :, :]
            elif i == last and tensor.shape == (2, 2, 2):
                tensor = tensor[:, :, 0:1]
            tensors.append(
                npc.Array.from_ndarray(
                    tensor,
                    legcharges=[LegCharge.from_trivial(s) for s in tensor.shape],
                    labels=["vL", "p", "vR"],
                )
            )

        # Trivial (all-ones) singular values on every bond.
        SVs = (
            [np.ones([1])]
            + [np.ones(tensors[i].get_leg("vR").ind_len) for i in range(nwires - 1)]
            + [np.ones([1])]
        )
        return MPS(sites, tensors, SVs, bc="finite")

    # --- MPO conversion -------------------------------------------------
    raw_tensors = [np.asarray(node.tensor) for node in sorted_nodes]
    chi = raw_tensors[0].shape[3] if nwires > 1 else 1
    IdL = 0
    IdR = max(chi - 1, 0)

    # Work on copies so the QuOperator's node tensors are not modified.
    tensors = [t.copy() for t in raw_tensors]

    # Zero-pad boundary virtual legs up to chi and place the boundary data
    # at the IdL / IdR index that TeNPy will use.
    if tensors[0].shape[0] < chi:
        t_left = tensors[0]
        padded_left = np.zeros((chi,) + t_left.shape[1:], dtype=t_left.dtype)
        padded_left[IdL, ...] = t_left[0, ...]
        tensors[0] = padded_left

    if tensors[-1].shape[3] < chi:
        t_right = tensors[-1]
        padded_right = np.zeros(t_right.shape[:3] + (chi,), dtype=t_right.dtype)
        padded_right[..., IdR] = t_right[..., 0]
        tensors[-1] = padded_right

    tenpy_Ws = []
    for tensor in tensors:
        # (wL, p, p*, wR) -> (wL, wR, p, p*), matching the labels below.
        tensor = np.transpose(tensor, (0, 3, 1, 2))
        tenpy_Ws.append(
            npc.Array.from_ndarray(
                tensor,
                legcharges=[LegCharge.from_trivial(s) for s in tensor.shape],
                labels=["wL", "wR", "p", "p*"],
            )
        )

    return MPO(sites, tenpy_Ws, bc="finite", IdL=IdL, IdR=IdR)

In [None]:
def _tenpy_mpo_to_dense(mpo: Any) -> Any:
    """
    Return the dense matrix of a finite TeNPy MPO.

    Uses ``mpo.to_matrix()`` when available.  Otherwise the W tensors are
    contracted with tensornetwork, with the outer virtual legs closed by unit
    vectors at ``IdL[0]`` / ``IdR[-1]`` (index 0 for a dimension-1 leg).
    In that fallback, rows follow the ``"p*"`` legs and columns the ``"p"``
    legs.
    """
    try:
        return mpo.to_matrix()
    except AttributeError:
        pass

    canonical_order = ["wL", "wR", "p", "p*"]
    # ``transpose`` returns a reordered copy.  ``get_W`` may hand back the
    # stored array, so ``itranspose`` here could reorder the MPO in place.
    nodes = [
        tn.Node(mpo.get_W(i).transpose(canonical_order).to_ndarray())
        for i in range(mpo.L)
    ]
    for left, right in zip(nodes[:-1], nodes[1:]):
        left[1] ^ right[0]

    chi_left = nodes[0].shape[0]
    chi_right = nodes[-1].shape[1]
    # IdL / IdR may be stored per bond; take the outermost left / right entry.
    left_bc = np.zeros(chi_left)
    left_bc[0 if chi_left == 1 else np.atleast_1d(mpo.IdL)[0]] = 1.0
    right_bc = np.zeros(chi_right)
    right_bc[0 if chi_right == 1 else np.atleast_1d(mpo.IdR)[-1]] = 1.0

    left_node = tn.Node(left_bc)
    right_node = tn.Node(right_bc)
    nodes[0][0] ^ left_node[0]
    nodes[-1][1] ^ right_node[0]

    result = tn.contractors.auto(
        nodes + [left_node, right_node],
        output_edge_order=[n[3] for n in nodes] + [n[2] for n in nodes],
    )
    dim = np.prod(mpo.dim)
    return tc.backend.reshape(result.tensor, (dim, dim))


def _tenpy_mps_to_vector(mps: Any) -> Any:
    """
    Return the state of a finite TeNPy MPS as a dense array.

    Uses ``mps.to_ndarray()`` when available.  Otherwise the full theta over
    all sites is flattened to a vector.
    """
    try:
        return mps.to_ndarray()
    except AttributeError:
        return mps.get_theta(0, mps.L).to_ndarray().reshape(-1)


@pytest.mark.parametrize("backend", [lf("npb"), lf("tfb"), lf("jaxb")])
def test_tenpy_roundtrip(backend):
    """Random MPO / MPS survive ``tenpy2qop`` -> ``qop2tenpy`` unchanged."""
    try:
        import tensornetwork as tn
        from tenpy.networks.mps import MPS
        from tenpy.networks.mpo import MPO
        from tenpy.networks.site import Site
        from tenpy.linalg.charges import LegCharge
        from tenpy.linalg import np_conserved as npc
    except ImportError:
        pytest.skip("TeNPy or TensorNetwork is not installed")

    # Fixed seed so that any failure is reproducible.
    rng = np.random.default_rng(42)
    phys_dim = 2

    def to_npc(data, labels):
        """Wrap a dense array as a charge-free TeNPy array with ``labels``."""
        return npc.Array.from_ndarray(
            data,
            labels=labels,
            legcharges=[LegCharge.from_trivial(s) for s in data.shape],
        )

    # --- MPO roundtrip ----------------------------------------------------
    nwires_mpo = 3
    chi_mpo = 4
    sites = [Site(LegCharge.from_trivial(phys_dim), "q") for _ in range(nwires_mpo)]
    mpo_shapes = (
        [(1, chi_mpo, phys_dim, phys_dim)]
        + [(chi_mpo, chi_mpo, phys_dim, phys_dim)] * (nwires_mpo - 2)
        + [(chi_mpo, 1, phys_dim, phys_dim)]
    )
    Ws = [to_npc(rng.random(shape), ["wL", "wR", "p", "p*"]) for shape in mpo_shapes]
    mpo_original = MPO(sites, Ws, IdL=0, IdR=chi_mpo - 1)

    qop_mpo = tc.quantum.tenpy2qop(mpo_original)
    mpo_roundtrip = tc.quantum.qop2tenpy(qop_mpo)

    np.testing.assert_allclose(
        _tenpy_mpo_to_dense(mpo_original),
        _tenpy_mpo_to_dense(mpo_roundtrip),
        atol=1e-5,
    )

    # --- MPS roundtrip ----------------------------------------------------
    nwires_mps = 4
    chi_mps = 5
    sites_mps = [Site(LegCharge.from_trivial(phys_dim), "q") for _ in range(nwires_mps)]
    mps_shapes = (
        [(1, phys_dim, chi_mps)]
        + [(chi_mps, phys_dim, chi_mps)] * (nwires_mps - 2)
        + [(chi_mps, phys_dim, 1)]
    )
    Bs = [to_npc(rng.random(shape), ["vL", "p", "vR"]) for shape in mps_shapes]
    # Trivial singular values on every bond (outer bonds have dimension 1).
    SVs = (
        [np.ones([1])]
        + [np.ones(B.get_leg("vR").ind_len) for B in Bs[:-1]]
        + [np.ones([1])]
    )
    mps_original = MPS(sites_mps, Bs, SVs)

    qop_mps = tc.quantum.tenpy2qop(mps_original)
    mps_roundtrip = tc.quantum.qop2tenpy(qop_mps)

    np.testing.assert_allclose(
        _tenpy_mps_to_vector(mps_original),
        _tenpy_mps_to_vector(mps_roundtrip),
        atol=1e-5,
    )


# Run the roundtrip test directly (outside pytest).  ``backend`` here is the
# module-level tensorcircuit backend imported from ``tensorcircuit.cons``.
test_tenpy_roundtrip(backend=backend)