NOTE: I thought Google's TensorNetwork library might help with some of what I'm doing. My conclusion is that it is a promising library, but honestly it is not very useful here unless it starts caching the results of large contractions.

In [None]:
%load_ext autoreload
%autoreload 2

In [336]:
import tensornetwork as tn
import numpy as np
from pprint import pprint
from tebd import tebd
from misc import mps_2form, mps_overlap, mpo_on_mpo
from state_approximation import mps2mpo, mpo2mps, multiple_diagonal_expansions,\
                                diagonal_expansion, contract_diagonal_expansion,\
                                contract_series_diagonal_expansions, mpo_on_mps,\
                                entanglement_entropy
from moses_simple import moses_move as moses_move_simple
import matplotlib.pyplot as plt
from disentanglers import disentangle_S2
import scipy
from scipy.stats import unitary_group

In [6]:
def process_network(network, labels=None):
    """Wrap an MPS or MPO (list of np.ndarray) as connected ``tn.Node`` objects.

    If ``network`` already holds ``tn.Node`` objects it is returned unchanged.
    Otherwise every tensor is wrapped in a ``tn.Node``, and the last leg of each
    site is connected to the "S" leg of the next site:

    * rank 3 (MPS, legs ``pSN``): leg ``-1`` of site i to leg ``1`` of site i+1;
    * rank 4 (MPO, legs ``WESN``): leg ``-1`` of site i to leg ``2`` of site i+1.

    Parameters
    ----------
    network : list
        List of np.ndarray with leg format either pSN (MPS) or WESN (MPO).
    labels : list, optional
        Labels for the nodes; node i is named ``str(labels[i])``.

    Returns
    -------
    np.ndarray
        Object array of connected ``tn.Node`` (or ``network`` itself if it
        already contained nodes).

    Raises
    ------
    ValueError
        If the tensors are neither rank 3 nor rank 4.
    """
    if isinstance(network[0], tn.Node):
        return network
    tn_network = np.array([tn.Node(tensor) for tensor in network])
    L = len(network)
    if labels:
        for i in range(L):
            tn_network[i].name = str(labels[i])

    # Pick the leg of site i+1 that receives site i's last leg. The previous
    # version used two independent ``if`` statements, so a valid rank-3 MPS
    # fell through to the MPO branch's ``else`` and raised ValueError.
    rank = tn_network[0].get_rank()
    if rank == 3:    # MPS, legs pSN
        south_leg = 1
    elif rank == 4:  # MPO, legs WESN
        south_leg = 2
    else:
        raise ValueError("Unrecognized tensor network.")

    for i in range(L - 1):
        tn_network[i][-1] ^ tn_network[i + 1][south_leg]
    return tn_network

def _mpo_on_mpo_straight(mpoL, mpoR):
    """Site-by-site ("straight") product of two MPOs made of ``tn.Node``.

    Both inputs are expected to have the bonds between neighbouring sites
    already connected (vertical contractions already performed). For each
    site, leg 1 of ``mpoL[i]`` is joined to leg 0 of ``mpoR[i]``. The two nodes
    are contracted, and the resulting node's edges at positions (1, 4) and
    then (1, 3) are flattened into single edges.

    Parameters
    ----------
    mpoL, mpoR : list of tn.Node
        Left and right MPOs. Their edges are modified in place.

    Returns
    -------
    list of tn.Node
        One merged node per site.
    """
    merged_sites = []
    for site, left in enumerate(mpoL):
        right = mpoR[site]
        left[1] ^ right[0]
        merged = tn.contract_between(left, right)
        tn.flatten_edges([merged[1], merged[4]])
        tn.flatten_edges([merged[1], merged[3]])
        merged_sites.append(merged)
    return merged_sites

def _mpo_on_mpo_shifted(mpoL, mpoR):
    """ mpo on mpo with the shifted protocol. Does this by contracting the
    entire tensor then using a series of qr decompositions to return to
    isometric form. 

    Parameters
    ----------
    mpoL, mpoR : list of tn.Node
        Left and right MPOs of equal length L, e.g. as produced by
        ``process_network`` (4-leg nodes with neighbouring sites already
        connected). New edges are added between them and the whole network is
        contracted, so the inputs are consumed by this call.

    Returns
    -------
    list of tn.Node
        L nodes: the Q factors of L-1 successive QR splits (with axes
        reordered), followed by the final remainder node.

    NOTE(review): the cost of contracting the full network grows with L;
    ``ordered_get_all_dangling`` is defined in a later notebook cell.
    """
    L = len(mpoL)
    # Shifted vertical wiring: leg 1 of mpoL[i+1] joins leg 0 of mpoR[i];
    # the two unmatched end legs are joined explicitly.
    mpoL[0][1] ^ mpoR[0][2]
    for i in range(L-1):
        mpoL[i+1][1] ^ mpoR[i][0]
    mpoL[L-1][3] ^ mpoR[L-1][0]
    
    # Output edge order: site by site, dangling edges of mpoL[i] then mpoR[i].
    edge_list = []
    for i in range(L):
        edge_list.extend(ordered_get_all_dangling([mpoL[i], mpoR[i]]))
        
    # Contract everything reachable into a single dense node.
    output = tn.contractors.auto(tn.reachable(mpoL[0]), edge_list)
    out = []
    # Peel sites off from the left: the first 3 edges of the current remainder
    # go into the Q factor, the rest stay in the remainder.
    node_q, output = tn.split_node_qr(output, left_edges=output[0:3], right_edges=output[3:])
    out.append(node_q.reorder_axes([0,2,1,3]))
    for i in range(L-2):
        node_q, output = tn.split_node_qr(output, left_edges=output[:3], right_edges=output[3:])
        out.append(node_q.reorder_axes([1,2,0,3]))
    # Whatever is left after the last split becomes the final site.
    out.append(output.reorder_axes([1,2,0,3]))
    return out

def _mpo_on_mpo_shifted_efficient(mpoL, mpoR):
    """ mpo_on_mpo with the efficient shifted protocol. This should be your default.
    Only needs to do qr decompositions on the first and last tensor.

    Parameters
    ----------
    mpoL, mpoR : list of tn.Node
        Left and right MPOs of equal length L, e.g. as produced by
        ``process_network`` (4-leg nodes, neighbouring sites connected).
        Their edges are rewired and contracted, so the inputs are consumed.

    Returns
    -------
    list of tn.Node
        Two nodes from the QR split at the left end, the merged middle sites,
        then two nodes from the QR split at the right end.

    NOTE(review): this routine looks unfinished and has not been verified end
    to end. Suspected remaining problems:
      * the middle-site flattening loop runs before ``mpoL[1]`` and ``mpoR[0]``
        are merged, so for i == 2 the edges being flattened connect to two
        different nodes, which ``tn.flatten_edges`` may reject;
      * in the right-end block ``split_node_qr`` receives only two right edges,
        so the R factor presumably has 3 edges, yet edges 3 and 4 are
        flattened; the Q factor is also appended without axis reordering.
    """
    L = len(mpoL)
    mpo_out = []
    # Middle sites: with the shift, mpoL[i] sits on mpoR[i-1].
    for i in range(2, L-1):
        mpoL[i][1] ^ mpoR[i-1][0]
        mpo_out.append(tn.contract_between(mpoL[i], mpoR[i-1]))
    # Merge pairs of parallel edges on each middle node (presumably the two
    # S legs, then the two N legs -- confirm against tensornetwork's edge order).
    for i in range(2,L-1):
        tn.flatten_edges([mpo_out[i-2][1], mpo_out[i-2][4]])
        tn.flatten_edges([mpo_out[i-2][1], mpo_out[i-2][3]])
    
    # Left end: contract mpoL[0], mpoL[1] and mpoR[0] into one node, then
    # split it back into two with a QR decomposition.
    mpoL[1][1] ^ mpoR[0][0]
    mpoL[0][1] ^ mpoR[0][2]
    
    first_tensor = tn.contract_between(mpoL[0], mpoL[1])
    first_tensor = tn.contract_between(first_tensor, mpoR[0])
    first_tensor, second_tensor = tn.split_node_qr(first_tensor, 
                                                   left_edges=[first_tensor[0], first_tensor[1]],
                                                   right_edges=[first_tensor[i] for i in [2,4,3,5]]
                                                )
    # Give the Q factor a trailing size-1 leg and reorder it to 4 legs.
    first_shape = first_tensor.shape
    first_tensor = tn.Node(first_tensor.tensor.reshape((*first_shape, 1)))
    first_tensor = tn.transpose(first_tensor, [0,3,1,2])
    
    # flatten_edges updates the node in place; its return value is the new edge.
    tn.flatten_edges([second_tensor[3], second_tensor[4]])
    second_tensor = tn.transpose(second_tensor, [1,2,0,3])
    mpo_out = [first_tensor, second_tensor] + mpo_out

    # Right end: contract mpoL[L-1], mpoR[L-2] and mpoR[L-1], then QR-split.
    mpoL[L-1][3] ^ mpoR[L-1][0]
    mpoL[L-1][1] ^ mpoR[L-2][0]
    first_tensor = tn.contract_between(mpoL[L-1], mpoR[L-2])
    first_tensor = tn.contract_between(first_tensor, mpoR[L-1])
    first_tensor, second_tensor = tn.split_node_qr(first_tensor,
                                  left_edges=[first_tensor[i] for i in range(4)],
                                  right_edges=[first_tensor[i] for i in [4,5]])
    # Fix: flatten_edges returns an Edge, not a Node. Assigning it to
    # second_tensor (as before) made the tn.transpose call below fail.
    # This now mirrors the left-end handling above.
    tn.flatten_edges([second_tensor[3], second_tensor[4]])
    second_tensor = tn.transpose(second_tensor, [1,2,0,3])
    mpo_out.extend([first_tensor, second_tensor])
    return mpo_out

def mpo_on_mpo_generalized(mpoL, mpoR, mode='shifted'):
    """Contract two MPOs into one using tensornetwork, with a selectable protocol.

    Both MPOs are first converted with ``process_network``. Shallow copies of
    the input lists are passed, so the caller's lists are not replaced.

    Parameters
    ----------
    mpoL :
        Left MPO (list of WESN arrays, or of already-processed ``tn.Node``).
    mpoR :
        Right MPO, in the same format as ``mpoL``.
    mode : str
        Contraction method:
        'straight' -- vanilla site-by-site MPO-on-MPO product;
        'shifted' -- the shifted protocol. It contracts the whole network, then
            rebuilds the MPO with a series of QR decompositions (may be expensive);
        'shifted_efficient' -- shifted protocol with QRs only at the ends.

    Returns
    -------
    list of tn.Node
        Whatever the selected helper returns.

    Raises
    ------
    ValueError
        If ``mode`` is not one of the options above.
    """
    left_nodes = process_network(mpoL.copy())
    right_nodes = process_network(mpoR.copy())

    protocols = (
        ('shifted', _mpo_on_mpo_shifted),
        ('shifted_efficient', _mpo_on_mpo_shifted_efficient),
        ('straight', _mpo_on_mpo_straight),
    )
    for name, contract in protocols:
        if mode == name:
            return contract(left_nodes, right_nodes)
    raise ValueError("Invalid mode")
    
def series_contraction(As, Lambda):
    """Left-fold ``mpo_on_mpo`` over the MPOs in ``As`` and finally ``Lambda``.

    Computes ``mpo_on_mpo(...mpo_on_mpo(As[0], As[1])..., Lambda)``.

    NOTE(review): ``mpo_on_mpo`` is looked up at call time. This notebook
    imports it from ``misc`` and also redefines it in a later cell, so which
    version runs depends on cell execution order -- confirm the intended one.
    """
    accumulated = As[0]
    for operator in [*As[1:], Lambda]:
        accumulated = mpo_on_mpo(accumulated, operator)
    return accumulated


def remove_dangling_legs(mps):
    """Drop the size-1 outer bond legs of a 3-leg-per-site MPS.

    The first site's leg 1 and the last site's leg 2 must both have
    dimension 1 (checked with ``assert``). Those legs are removed by
    reshaping. The list is modified in place and also returned.
    """
    phys_first, outer_left, inner_first = mps[0].shape
    phys_last, inner_last, outer_right = mps[-1].shape
    assert outer_left == outer_right == 1
    mps[0] = mps[0].reshape((phys_first, inner_first))
    mps[-1] = mps[-1].reshape((phys_last, inner_last))
    return mps


In [269]:
def _sweep_disentangle_single_pass(Psi):
    """One left-to-right disentangling sweep over an MPS.

    For each bond (i, i+1), the two site tensors are merged into
    ``theta[chiL, d1, d2, chiR]``. One iteration of ``disentangle_S2`` is run
    on theta, and the result is split back with a QR decomposition: the Q
    factor goes to site i and R to site i+1. Site tensors are indexed as
    (physical, left bond, right bond), judging from the contraction pattern.

    Parameters
    ----------
    Psi : list of np.ndarray
        MPS tensors. The list is updated in place and also returned.

    Returns
    -------
    Psi : list of np.ndarray
        The same list after the sweep.
    Us : list
        The unitary returned by ``disentangle_S2`` for each bond, in sweep order.
    """
    unitaries = []
    for site in range(len(Psi) - 1):
        two_site = np.tensordot(Psi[site], Psi[site + 1], [2, 1])
        two_site = two_site.transpose([1, 0, 2, 3])  # -> (chiL, d1, d2, chiR)
        two_site, U, _ = disentangle_S2(two_site, max_iter=1, eps=1.e-16)
        chi_left, d_left, d_right, chi_right = two_site.shape
        q, r = np.linalg.qr(two_site.reshape(chi_left * d_left, chi_right * d_right))
        Psi[site] = q.reshape((chi_left, d_left, -1)).transpose([1, 0, 2])
        Psi[site + 1] = r.reshape((-1, d_right, chi_right)).transpose([1, 0, 2])
        unitaries.append(U)
    return Psi, unitaries

def sweep_disentangle_one_side(Lambda):
    """Disentangle ``Lambda`` with one forward and one backward sweep.

    Parameters
    ----------
    Lambda : list of np.Array
        Physical wavefunction. If its tensors have 4 legs it is first
        converted with ``mpo2mps``; otherwise a shallow copy of the list is used.

    Returns
    -------
    Psi : list of np.Array
        The disentangled wavefunction, converted with ``mps2mpo``.
    [U1, U2] : list of two lists
        Unitaries from the forward sweep (U1) and the back sweep (U2).
        NOTE: U2 is expressed in the reversed-chain orientation used during
        the back sweep.
    """
    def _reverse_chain(chain):
        # Reverse site order and swap the two bond legs of each tensor.
        return [tensor.transpose([0, 2, 1]) for tensor in chain[::-1]]

    Psi = Lambda.copy()
    if Psi[0].ndim == 4:
        Psi = mpo2mps(Lambda)

    Psi, forward_unitaries = _sweep_disentangle_single_pass(Psi)
    Psi, backward_unitaries = _sweep_disentangle_single_pass(_reverse_chain(Psi))
    return mps2mpo(_reverse_chain(Psi)), [forward_unitaries, backward_unitaries]

def sweep_and_disentangle(A0, Lambda):
    """One back-and-forth disentangling sweep on Lambda, compensated on A0.

    Given A0 and Lambda such that Psi = A0.Lambda, this runs
    ``sweep_disentangle_one_side`` on Lambda, working from its other end.
    The conjugated unitaries are then applied to A0 with
    ``split_and_apply_unitary``: forward-sweep ones starting from form 'A',
    back-sweep ones from form 'B'.

    Parameters
    ----------
    A0 : list of np.Array
    Lambda : list of np.Array

    Returns
    -------
    A0, Lambda : the updated (disentangled) versions.
    """
    def _flip_mpo(mpo):
        # Reverse the site order and swap legs 2 and 3 of every tensor.
        return [tensor.transpose([0, 1, 3, 2]) for tensor in mpo[::-1]]

    Lambda, (forward, backward) = sweep_disentangle_one_side(_flip_mpo(Lambda))
    Lambda = _flip_mpo(Lambda)

    # Conjugate the gates and restore the original site order.
    forward = [gate.conj() for gate in reversed(forward)]
    backward = [gate.conj() for gate in reversed(backward)]
    A0 = split_and_apply_unitary(A0, forward, form='A')
    A0 = split_and_apply_unitary(A0, backward, form='B')
    return A0, Lambda

def get_eye(d1, d2):
    """Identity on a d1*d2-dimensional two-site space, as a 4-leg tensor.

    The result has shape (d1, d2, d1, d2), and ``eye[i, j, k, l]`` is 1.0
    exactly when ``i == k`` and ``j == l`` (0.0 otherwise).
    """
    # kron(I_d1, I_d2) is the (d1*d2)-dimensional identity, so build it directly.
    identity = np.eye(d1 * d2)
    return identity.reshape([d1, d2, d1, d2])

def _apply_site_gate(tensor, U):
    """Apply a 4-leg gate ``U`` to legs 1 and 3 of a 4-leg site tensor.

    Legs 1 and 3 of ``tensor`` are contracted with legs 2 and 3 of ``U``.
    The gate's legs 0 and 1 take their places, so the leg order is kept.
    """
    return np.tensordot(tensor, U, [[1, 3], [2, 3]]).transpose([0, 2, 1, 3])


def _apply_pair_gate(pair, U):
    """Apply a 4-leg gate ``U`` to legs 1 and 4 of a (W1,E1,S1,W2,E2,N2) pair.

    Legs 1 and 4 are contracted with legs 2 and 3 of ``U``. The gate's legs
    0 and 1 are put back at positions 1 and 4.
    """
    return np.tensordot(pair, U, [[1, 4], [2, 3]]).transpose([0, 4, 1, 2, 5, 3])


def _svd_split_pair(pair, tol=1.e-10):
    """Split a (W1,E1,S1,W2,E2,N2) pair back into two 4-leg tensors by truncated SVD.

    Singular values <= ``tol`` are discarded. Returns ``(left, right)``:
    ``left`` is the isometric factor ``u`` reshaped to (W1, E1, S1, chi), and
    ``right`` carries the singular values, arranged as (W2, E2, chi, N2).
    """
    W1, E1, S1, W2, E2, N2 = pair.shape
    u, s, v = np.linalg.svd(pair.reshape(W1 * E1 * S1, W2 * E2 * N2),
                            full_matrices=False)
    # Singular values come sorted in descending order, so the mask keeps a
    # leading block of them.
    s = s[s > tol]
    chi = len(s)
    left = u[:, :chi].reshape(W1, E1, S1, chi)
    right = (np.diag(s) @ v[:chi, :]).reshape(chi, W2, E2, N2)
    return left, right.transpose([1, 2, 0, 3])


def split_and_apply_unitary(Psi, Us, form='B'):
    """ Given a wavefunction Psi and a list of two site unitaries U, sweeps along
    the wavefunction applying all the Us. This is designed to act on A0 in the
    shifted regime -- so it does not act on the first tensor, and acts on two
    legs of the last tensor. Starting from A form reverses this, but the broken
    symmetry due to the shift means we can't just reverse the wavefunction.

    Each two-site gate acts on leg 1 of both tensors of a neighbouring pair.
    The pair is then re-split with a truncated SVD (singular values <= 1e-10
    are dropped): the left factor is isometric and the singular values move
    to the right. The single gate on the end tensor acts on its legs 1 and 3.
    For every gate, legs 2 and 3 are contracted with the tensor and legs 0
    and 1 replace them.

    Parameters
    ----------
    Psi : list of np.Array
        The wavefunction as 4-leg tensors; leg 3 of site i is contracted with
        leg 2 of site i+1. With ``form='B'`` the list itself is updated in
        place (and returned); with ``form='A'`` a new list is returned.
    Us : list of np.Array or None
        List of L-1 two site unitaries.
        form 'B': ``Us[k]`` acts on sites (k+1, k+2); ``Us[-1]`` acts on the
        last site.
        form 'A': the chain is reversed first; ``Us[0]`` acts on the reversed
        first site and ``Us[k+1]`` on reversed sites (k, k+1).
        ``None`` means identity gates of matching dimensions.
    form : str
        Starting form, either A (arrows pointing up) or B (arrows pointing down)

    Returns
    -------
    list of np.Array
        The updated wavefunction, in the original site order.

    Ensure that the zeroth leg is acted on by the unitary.
    """
    L = len(Psi)
    assert form in ['A', 'B']

    if form == 'A':
        # Reverse the chain and swap legs 2 <-> 3 so the sweep runs left to right.
        Psi = [psi.transpose([0, 1, 3, 2]) for psi in Psi[::-1]]
        if Us is None:
            # Identity gates sized by the legs they are contracted with: legs
            # (1, 3) of Psi[0], then leg 1 of each pair. These previously used
            # legs (2, 1) and leg 0, which mismatched the contractions below.
            Us = [get_eye(Psi[0].shape[1], Psi[0].shape[3])]
            Us.extend(get_eye(Psi[i].shape[1], Psi[i + 1].shape[1])
                      for i in range(L - 2))

        Psi[0] = _apply_site_gate(Psi[0], Us[0])
        for i in range(L - 2):
            pair = np.tensordot(Psi[i], Psi[i + 1], [3, 2])
            pair = _apply_pair_gate(pair, Us[i + 1])
            Psi[i], Psi[i + 1] = _svd_split_pair(pair)
        # Undo the reversal.
        return [psi.transpose([0, 1, 3, 2]) for psi in Psi[::-1]]

    if Us is None:
        Us = [get_eye(Psi[i].shape[1], Psi[i + 1].shape[1]) for i in range(1, L - 1)]
        Us.append(get_eye(Psi[L - 1].shape[1], Psi[L - 1].shape[3]))
    for i in range(L - 1):
        pair = np.tensordot(Psi[i], Psi[i + 1], [3, 2])
        if i > 0:  # the first pair is only re-split, never gated
            pair = _apply_pair_gate(pair, Us[i - 1])
        Psi[i], Psi[i + 1] = _svd_split_pair(pair)
    Psi[L - 1] = _apply_site_gate(Psi[L - 1], Us[-1])
    return Psi

In [662]:
# Build a test state: run tebd(10, 1.5, 0.1) and convert its MPS to MPO form.
# Then expand it with multiple_diagonal_expansions (second argument 200).
# NOTE(review): Lambda = Psi.copy() is immediately overwritten by the unpacking
# on the last line. This cell unpacks 4 values from
# multiple_diagonal_expansions, while the In [335] and In [390] cells unpack 6;
# the signature seems to have changed between runs -- confirm.
tebd_state, _, _ = tebd(10, 1.5, 0.1)
Psi = mps2mpo(tebd_state.copy())
Lambda = Psi.copy()
As, Lambda, Ss, Lambdas = multiple_diagonal_expansions(Psi,200)

In [663]:
# Sanity check: contract As and Lambda back into a single state and compute
# its overlap with the original Psi (the recorded output is ~1.0).
#out = contract_diagonal_expansion(As[0], Lambda)
out = contract_series_diagonal_expansions(As, Lambda)
np.linalg.norm(mps_overlap(Psi, out)), mps_overlap(Psi, out)

(1.0000000000006972, 1.0000000000006972)

In [237]:
def contract_mpo(mpo):
    """Contract an MPO chain into one dense tensor.

    Starting from ``mpo[0]``, the last axis of the running result is
    contracted with axis 2 of each following site tensor. A single-site
    chain returns ``mpo[0]`` itself.
    """
    dense = mpo[0]
    for site_tensor in mpo[1:]:
        dense = np.tensordot(dense, site_tensor, [-1, 2])
    return dense
        

In [335]:
# NOTE(review): this unpacks 6 values from multiple_diagonal_expansions, whereas
# the In [662] cell unpacks 4; judging by execution counts, the 4-value cell
# ran later -- confirm against state_approximation before re-running.
As, Lambda, Ss, fidelity, cp, Lambdas = multiple_diagonal_expansions(Psi,200)
# Entanglement entropy of Lambda before and after 2000 forward/backward
# disentangling sweeps (recorded output: ~0.047 -> ~1.4e-05).
print(entanglement_entropy(Lambda))
for i in range(2000):
    Lambda, _ = sweep_disentangle_one_side(Lambda)
print(entanglement_entropy(Lambda))

0.047095146989858054
1.384817438467499e-05


In [390]:
# Experiment: after one diagonal expansion (Psi ~ A . Lambda), apply random
# two-site unitaries to Lambda and their conjugates to A, then check whether
# the contracted state still overlaps with Psi. Recorded outputs: 0.95 before,
# about -0.11-0.26j after, so this scheme as written does not preserve the state.
# NOTE(review): unpacks 6 values from multiple_diagonal_expansions; the In [662]
# cell unpacks 4 -- confirm which signature is current.
As, Lambda, Ss, fidelity, cp, Lambdas = multiple_diagonal_expansions(Psi,1)
A = As[-1]
out = contract_diagonal_expansion(A, Lambda)
print(mps_overlap(Psi, out))
#print(entanglement_entropy(Lambda))
# Reverse both MPOs and swap legs 2 and 3, so the loop works from the other end.
Lambda = [psi.transpose([0,1,3,2]) for psi in Lambda[::-1]]
A = [a.transpose([0,1,3,2]) for a in A[::-1]]
#Lambda = mpo2mps(mps_2form(Lambda, 'B'))
L = len(Lambda)
#print(L)
Us = []
# Only the first 3 bonds are processed. U is a fresh, unseeded random 4x4
# unitary (reshaped to 2x2x2x2) on every iteration.
for i in range(3):
    psi = np.tensordot(Lambda[i], Lambda[i+1], [3,2])
    W1,E1,S1,W2,E2,N2 = psi.shape
    #theta, U, _ = disentangle_S2(theta)
    U = unitary_group.rvs(4).reshape([2]*4)
    
    #theta = psi.reshape(S1*E1, W1, W2, E2*N2)
    theta = psi
    # Lambda is gated (on legs W1 and W2) only on even bonds, but U is
    # appended to Us on every iteration.
    if i % 2 == 0:
        theta = np.tensordot(psi, U, [[0,3],[3,2]]).transpose([4,0,1,5,2,3])
        #theta = np.tensordot(U, theta, [[2,3],[1,2]]).transpose([2,0,1,3])

    # Re-split the Lambda pair with a truncated SVD (values <= 1e-10 dropped).
    psi = theta.reshape(S1*E1*W1, E2*N2*W2)
    Us.append(U)
    u, s, v = np.linalg.svd(psi, full_matrices=False)


    s = s[s > 1.e-10]
    chi_max = len(s)
    psi = u[:,:chi_max]
    v = v[:chi_max,:]
    v = (np.diag(s) @ v)
    Lambda[i] = psi.reshape(W1,E1,S1,chi_max)
    Lambda[i+1] = v.reshape(chi_max,W2,E2,N2).transpose([1,2,0,3])
    if i == 0:
        # NOTE(review): here the conjugate gate is contracted with legs 2 and 1
        # of A[0] alone, not with a pair of A tensors.
        A[0] = np.tensordot(A[0], U.conj(), [[2,1],[2,3]]).transpose([0,3,2,1])
    else:
        psi = np.tensordot(A[i-1], A[i], [3,2])
        W1,E1,S1,W2,E2,N2 = psi.shape
        # NOTE(review): psi has axes (W1,E1,S1,W2,E2,N2). A plain reshape to
        # (S1*W1, E1, E2, W2*N2) does not reorder axes, so theta's axes 1 and 2
        # are generally not E1 and E2 -- a transpose may have been intended.
        theta = psi.reshape(S1*W1, E1, E2, W2*N2)


        #theta, U, _ = disentangle_S2(theta)
        if i % 2 == 0:
            theta = np.tensordot(U.conj(), theta, [[3,2],[1,2]]).transpose([2,0,1,3])

        # Re-split the A pair with a truncated SVD, as for Lambda above.
        psi = theta.reshape(S1*E1*W1, E2*N2*W2)
        u, s, v = np.linalg.svd(psi, full_matrices=False)


        s = s[s > 1.e-10]
        chi_max = len(s)
        psi = u[:,:chi_max]
        v = v[:chi_max,:]
        v = (np.diag(s) @ v)
        A[i-1] = psi.reshape(W1,E1,S1,chi_max)
        A[i] = v.reshape(chi_max,W2,E2,N2).transpose([1,2,0,3])

        
# Undo the reversal and compare the recombined state with Psi.
A = [psi.transpose([0,1,3,2]) for psi in A[::-1]]        
Lambda = [psi.transpose([0,1,3,2]) for psi in Lambda[::-1]]
#A = split_and_apply_unitary(A, Us)
out = contract_diagonal_expansion(A, Lambda)
#print(entanglement_entropy(Lambda))
mps_overlap(Psi, out)

0.9518665013167418


(-0.11306510428076419-0.2617473053128689j)

# ===============

In [448]:
import numpy as np
from copy import deepcopy

# Small random test fixture: a 2-site A MPO and a 2-site Lambda MPO with the
# leg shapes below, plus deep copies (A2, Lambda2), presumably kept as
# untouched references.
# NOTE(review): np.random is not seeded, so the values differ between runs.
a_shape = [(2, 1, 1, 2), (2, 2, 2, 2)]
lambda_shape = [(2, 1, 1, 2), (2, 1, 2, 4)]
A1 = [np.random.rand(*shape) for shape in a_shape]
Lambda = [np.random.rand(*shape) for shape in lambda_shape]
A2 = deepcopy(A1)
Lambda2 = deepcopy(Lambda)
def process_mpos(mpo):
    """Wrap each MPO tensor in a ``tn.Node`` and chain neighbouring sites.

    The last leg of site i is connected to leg 2 of site i+1. Returns a
    plain list of nodes; unlike ``process_network`` there is no rank check.
    """
    nodes = [tn.Node(tensor) for tensor in mpo]
    for left, right in zip(nodes, nodes[1:]):
        left[-1] ^ right[2]
    return nodes
import sys
def mpo_on_mpo(mpoL, mpoR):
    """Contract two MPOs of ``tn.Node`` with shifted wiring into a single node.

    NOTE(review): this redefines the ``mpo_on_mpo`` imported from ``misc`` at
    the top of the notebook. After this cell runs, later calls by that name
    (e.g. from ``series_contraction``) use this version, which returns one
    dense node rather than a list of site tensors.

    Wiring: leg 1 of ``mpoL[i+1]`` joins leg 0 of ``mpoR[i]``. The end legs
    ``mpoL[0][1]``/``mpoR[0][2]`` and ``mpoL[L-1][3]``/``mpoR[L-1][0]`` are
    joined too. The inputs are consumed (edges are added and contracted).

    Parameters
    ----------
    mpoL, mpoR : list of tn.Node
        Equal-length MPOs with neighbouring sites already connected
        (presumably built with ``process_mpos``).

    Returns
    -------
    tn.Node
        Result of ``tn.contractors.auto`` over the whole network. Its dangling
        edges are ordered site by site: mpoL[i]'s, then mpoR[i]'s.
    """
    L = len(mpoL)
    mpoL[0][1] ^ mpoR[0][2]
    for i in range(L-1):
        mpoL[i+1][1] ^ mpoR[i][0]
    mpoL[L-1][3] ^ mpoR[L-1][0]
    
    # Fix the output edge order: dangling edges of each site pair, in site order.
    edge_list = []
    for i in range(L):
        edge_list.extend(ordered_get_all_dangling([mpoL[i], mpoR[i]]))
    output = tn.contractors.auto(tn.reachable(mpoL[0]), output_edge_order=edge_list)

    return output

def ordered_get_all_dangling(nodes):
    """Return the dangling edges of ``nodes`` as a list, in a deterministic order.

    Edges are listed node by node (in the order given), and within each node
    in the order of ``node.edges``. Only edges whose ``is_dangling()`` is
    true are kept.
    """
    return [edge
            for node in nodes
            for edge in node.edges
            if edge.is_dangling()]

In [90]:
def setup_mpo_network(lenv, renv, mpo):
    """Wire three MPO rows (lenv, mpo, renv) into one network with shifted bonds.

    Each row is wrapped as ``tn.Node`` objects named ``lenv_i``, ``mpo_i`` and
    ``renv_i`` for i in range(len(mpo)). Connections:

    * within each row: leg 3 of site i to leg 2 of site i+1;
    * between consecutive rows r and r+1, shifted by one site: leg 1 of row r,
      site i to leg 0 of row r+1, site i-1 (for i >= 1);
    * at the ends: leg 3 of row r, site L-1 to leg 0 of row r+1, site L-1;
      and leg 1 of row r, site 0 to leg 2 of row r+1, site 0.

    Parameters
    ----------
    lenv, renv, mpo : list of np.ndarray
        Rank-4 tensors, presumably in the notebook's WESN leg order (TODO
        confirm). ``lenv`` and ``renv`` need at least ``len(mpo)`` entries.

    Returns
    -------
    list of three lists of tn.Node
        ``[lenv_nodes, mpo_nodes, renv_nodes]``.
    """
    # (An unused local ``dl = lenv[0].shape[0]`` was removed.)
    L = len(mpo)
    all_mpos = [[], [], []]
    for i in range(L):
        all_mpos[0].append(tn.Node(lenv[i], name=f"lenv_{i}"))
        all_mpos[1].append(tn.Node(mpo[i], name=f"mpo_{i}"))
        all_mpos[2].append(tn.Node(renv[i], name=f"renv_{i}"))

    # Horizontal bonds inside each row.
    for i in range(L - 1):
        for j in range(3):
            all_mpos[j][i][3] ^ all_mpos[j][i + 1][2]

    # Vertical bonds, shifted by one site between consecutive rows.
    for i in range(1, L):
        all_mpos[0][i][1] ^ all_mpos[1][i - 1][0]
        all_mpos[1][i][1] ^ all_mpos[2][i - 1][0]

    # Close the shift at the right end ...
    all_mpos[0][L - 1][3] ^ all_mpos[1][L - 1][0]
    all_mpos[1][L - 1][3] ^ all_mpos[2][L - 1][0]

    # ... and at the left end.
    all_mpos[0][0][1] ^ all_mpos[1][0][2]
    all_mpos[1][0][1] ^ all_mpos[2][0][2]

    return all_mpos

def setup_environment(all_mpos, i, mps_left, mps_right):
    """ Intended to remove node i and contract the remaining network.

    NOTE(review): as written, no node is removed -- ``i`` is unused (the
    commented-out edge-order code would also shadow it), and ``L``, ``mps_L``
    and ``mps_R`` are unused. ``exp_val_mpo`` is not defined or imported in
    the visible notebook, so this presumably raises NameError unless it is
    defined elsewhere -- confirm.

    Parameters
    ----------
    all_mpos :
        Passed to ``exp_val_mpo`` and replaced by its second return value;
        presumably the ``[lenv, mpo, renv]`` node rows from
        ``setup_mpo_network`` -- confirm.
    i : int
        Index of the node to remove (currently ignored).
    mps_left, mps_right :
        Passed to ``exp_val_mpo``.

    Returns
    -------
    The node produced by ``tn.contractors.auto`` when contracting every node
    reachable from ``all_mpos[0][0]`` (no output edge order is given).
    """
    L = len(mps_left)
    mps_L, all_mpos, mps_R = exp_val_mpo(mps_left, mps_right, all_mpos)
    
    #edge_order = []
    #edge_order.extend([all_mpos[0][i][0] for i in range(L)])
    #edge_order.extend([all_mpos[2][i][1] for i in range(L)])

    #edge_order.append(all_mpos[0][0][2])
    #edge_order.append(all_mpos[2][-1][3])
    
    output = tn.contractors.auto(tn.reachable(all_mpos[0][0]))
    return(output)

In [91]:
def contract_mps(mps):
    """Contract an open-boundary MPS into a dense state tensor.

    Parameters
    ----------
    mps : list of np.ndarray
        Site tensors with legs (p, chiL, chiR). Leg ``chiR`` of site i is
        contracted with leg ``chiL`` of site i+1. The outer bonds (``chiL`` of
        the first site, ``chiR`` of the last) must have dimension 1.

    Returns
    -------
    np.ndarray
        Tensor of shape ``(p_0, p_1, ..., p_{L-1})``, one axis per site, in
        site order.
    """
    t = mps[0]
    for tensor in mps[1:]:
        t = np.tensordot(t, tensor, [-1, 1])
    # The physical legs end up in site order, so reshaping only drops the two
    # size-1 outer bonds. Each site's own physical dimension is used; the
    # previous version assumed mps[0].shape[0] for every site, which broke
    # for non-uniform local dimensions.
    return t.reshape([tensor.shape[0] for tensor in mps])

def tensor_overlap(t1, t2):
    """Fully contract two equally shaped tensors: the sum over all indices of t1 * t2.

    No complex conjugation is applied to either argument. The shapes are
    checked with ``assert``. Returns the 0-d array produced by ``np.tensordot``.
    """
    assert t1.shape == t2.shape
    # An integer ``axes=n`` pairs the last n axes of t1 with the first n axes
    # of t2, in order -- i.e. every axis when n == t1.ndim.
    return np.tensordot(t1, t2, axes=t1.ndim)