In [1]:
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree
from scipy.optimize import minimize
import numpy as np # Required for interfacing with SciPy

# JAX configuration for 64-bit precision
jax.config.update("jax_enable_x64", True)

# ------------------------------------------------------------------
# 1. JAX-COMPATIBLE MPS AND MODEL FUNCTIONS
#    (Adapting the logic from the provided classes)
# ------------------------------------------------------------------

# --- Model Definition ---
def build_tfi_hamiltonian(L, J, g):
    """
    Creates the TFI Hamiltonian as a list of two-site bond operators.
    This is the JAX-native version of TFIModel.init_H_bonds.
    """
    d = 2
    sx = jnp.array([[0., 1.], [1., 0.]])
    sz = jnp.array([[1., 0.], [0., -1.]])
    id = jnp.identity(d)
    
    H_bonds = []
    for i in range(L - 1):
        # Handle open boundary conditions for the g-term
        gL = 0.5 * g
        gR = 0.5 * g
        if i == 0:
            gL = g
        if i + 1 == L - 1:
            gR = g
            
        H_bond = -J * jnp.kron(sx, sx) - gL * jnp.kron(sz, id) - gR * jnp.kron(id, sz)
        H_bonds.append(jnp.reshape(H_bond, (d, d, d, d)))
        
    return H_bonds

# --- MPS Helper Functions ---
def get_theta2(psi_params, i):
    """
    Calculate effective two-site wave function on sites i, j=(i+1).
    JAX-native version of MPS.get_theta2.
    
    Args:
        psi_params: A tuple (Bs, Ss) representing the MPS.
        i: The left site index of the bond.
        
    Returns:
        The two-site tensor with legs (vL, i, j, vR).
    """
    Bs, Ss = psi_params
    j = i + 1
    # theta1 = diag(S_i) @ B_i
    theta1 = jnp.tensordot(jnp.diag(Ss[i]), Bs[i], (1, 0)) # vL [vL'], [vL] i vR
    # theta2 = theta1 @ B_j
    theta2 = jnp.tensordot(theta1, Bs[j], (2, 0)) # vL i [vR], [vL] j vR
    return theta2

def bond_expectation_value(psi_params, H_bonds):
    """
    Calculate expectation values for each bond Hamiltonian term.
    JAX-native version of MPS.bond_expectation_value.
    
    Note: This calculation *assumes* the MPS is in right-canonical form.
    The gradient descent procedure may temporarily violate this, but for small
    steps, the results remain meaningful for optimization.
    """
    L = len(psi_params[0])
    exp_values = []
    for i in range(L - 1):
        theta2 = get_theta2(psi_params, i) # vL i j vR
        
        # op_theta = H_bond @ theta2
        op_theta = jnp.tensordot(H_bonds[i], theta2, axes=([2, 3], [1, 2])) # i j [i*] [j*], vL [i] [j] vR
        
        # <theta|op|theta>
        val = jnp.tensordot(jnp.conj(theta2), op_theta, axes=([0, 1, 2, 3], [2, 0, 1, 3])) # [vL*] [i*] [j*] [vR*], [i] [j] [vL] [vR]
        exp_values.append(val.real)
        
    return jnp.array(exp_values)

# --- Total Energy Function (to be differentiated) ---
@jax.jit
def compute_energy(psi_params, H_bonds):
    """
    Evaluate the total energy E = <psi|H|psi> for the given MPS.
    JAX-native version of TFIModel.energy.
    """
    return jnp.sum(bond_expectation_value(psi_params, H_bonds))

# ------------------------------------------------------------------
# 2. INITIALIZATION AND OPTIMIZATION SETUP
# ------------------------------------------------------------------

def init_spinup_mps(L, chi):
    """
    Initializes an all-spins-up product state as a right-canonical MPS.
    `chi` is the desired initial bond dimension. It will be padded with zeros.
    """
    B = jnp.zeros((1, 2, chi))
    B = B.at[0, 0, 0].set(1.0)
    
    # To make it right-canonical, contract with conjugate
    # A = B.transpose(1, 0, 2) # i, vL, vR
    # C = jnp.tensordot(jnp.conj(A), A, axes=([0,1], [0,1])) -> should be identity
    # Our B is already RC for chi=1
    B_rc = jnp.zeros((chi, 2, chi))
    B_rc = B_rc.at[0,0,0].set(1.0)
    
    Bs = [B_rc] * L
    
    # Ss[0] and Ss[L] are trivial boundary bonds of dim 1
    S = jnp.zeros(chi)
    S = S.at[0].set(1.0)
    Ss = [jnp.ones(1)] + [S] * (L-1) + [jnp.ones(1)]
    
    # A more general random starting point:
    key = jax.random.PRNGKey(42)
    Bs = [jax.random.normal(key, (chi, 2, chi)) for _ in range(L)]
    Ss = [jnp.ones(1)] + [jnp.ones(chi)/jnp.sqrt(chi) for _ in range(L-1)] + [jnp.ones(1)]

    # Important: Convert from numpy arrays to a pytree of jax arrays
    return jax.tree_util.tree_map(jnp.array, (Bs, Ss))


def optimize_mps(L, J, g, chi_max, tol=1e-9, max_iter=500):
    """
    Main function to run the gradient descent optimization.
    """
    # 1. Setup the model and initial state
    H_bonds = build_tfi_hamiltonian(L, J, g)
    psi_params_init = init_spinup_mps(L, chi_max)

    # 2. JAX utility to flatten/unflatten the MPS parameters for SciPy
    # This turns our list of tensors (Bs, Ss) into a single flat vector.
    flat_params, unflatten_fn = ravel_pytree(psi_params_init)

    # 3. Create the loss function and gradient function for SciPy
    @jax.jit
    def value_and_grad_fn(flat_params):
        # Unflatten the vector back into MPS tensors
        psi_params = unflatten_fn(flat_params)
        
        # The canonical form is slightly broken by the update.
        # As a simple correction, we re-normalize the Schmidt values at each step.
        psi_params_normalized = (psi_params[0], [s / jnp.linalg.norm(s) for s in psi_params[1]])
        
        energy, grads = jax.value_and_grad(compute_energy)(psi_params_normalized, H_bonds)
        
        # Flatten the gradients to match the input vector format
        flat_grads, _ = ravel_pytree(grads)
        return energy, flat_grads

    # Wrapper for SciPy, which expects numpy arrays and separate value/gradient
    def loss_for_scipy(x):
        e, g = value_and_grad_fn(x)
        # Store gradient in a variable to be retrieved by jac_for_scipy
        # This is a standard pattern to avoid re-computing
        loss_for_scipy.grad = np.array(g, dtype=np.float64)
        return np.array(e, dtype=np.float64)

    def jac_for_scipy(x):
        return loss_for_scipy.grad

    print("🚀 Starting MPS optimization with SciPy's L-BFGS-B...")
    print(f"L={L}, J={J}, g={g}, chi_max={chi_max}")
    
    # 4. Run the optimization
    res = minimize(
        fun=loss_for_scipy,
        x0=np.array(flat_params, dtype=np.float64),
        method='L-BFGS-B',
        jac=jac_for_scipy,
        tol=tol,
        options={'maxiter': max_iter, 'disp': True}
    )
    
    # 5. Return results
    final_energy = res.fun
    final_params_flat = res.x
    psi_final = unflatten_fn(final_params_flat)
    
    print("\n✅ Optimization finished.")
    return final_energy, psi_final

# ------------------------------------------------------------------
# 3. EXECUTION
# ------------------------------------------------------------------

if __name__ == '__main__':
    # --- Parameters ---
    L = 14
    J = 1.0
    g = 0.9  # Ferromagnetic phase
    # g = 1.0  # Critical point
    # g = 1.5  # Paramagnetic phase
    CHI_MAX = 8 # Maximum bond dimension

    # --- Run Optimization ---
    ground_energy, final_psi = optimize_mps(L, J, g, CHI_MAX)

    print(f"\nFinal Ground State Energy: {ground_energy:.8f}")
    
    # Compare with a high-accuracy DMRG result for verification
    # For (L=14, J=1, g=0.9), E ~ -13.023
    # For (L=14, J=1, g=1.0), E ~ -13.593
    # For (L=14, J=1, g=1.5), E ~ -17.514
    print("\nNote: The 'bond_expectation_value' formula assumes a canonical MPS.")
    print("The gradient descent finds a good variational minimum but doesn't strictly enforce the canonical form during optimization.")

🚀 Starting MPS optimization with SciPy's L-BFGS-B...
L=14, J=1.0, g=0.9, chi_max=8


  res = minimize(


TypeError: dot_general requires contracting dimensions to have the same shape, got (1,) and (8,).