In [2]:
import jax
import jax.numpy as jnp
from jax import grad, jit
from functools import partial

In [3]:
class MPS:
    """
    Matrix Product State in right-canonical form.
    Bs: list of site tensors with shape (Dl, d, Dr)
    Ss: list of singular values between sites (Schmidt values)
    """
    def __init__(self, Bs, Ss):
        self.Bs = Bs
        self.Ss = Ss
        self.L = len(Bs)

    @staticmethod
    def init_spinup(L):
        B = jnp.zeros((1, 2, 1))
        B = B.at[0, 0, 0].set(1.0)
        S = jnp.ones((1,))
        Bs = [B for _ in range(L)]
        Ss = [S for _ in range(L)]
        return MPS(Bs, Ss)

    def flatten(self):
        return jnp.concatenate([b.ravel() for b in self.Bs] + [s.ravel() for s in self.Ss])

    @staticmethod
    def unflatten(vec, shapes_B, shapes_S):
        Bs, Ss = [], []
        idx = 0
        for shape in shapes_B:
            size = jnp.prod(jnp.array(shape))
            Bs.append(vec[idx:idx+size].reshape(shape))
            idx += size
        for shape in shapes_S:
            size = jnp.prod(jnp.array(shape))
            Ss.append(vec[idx:idx+size].reshape(shape))
            idx += size
        return MPS(Bs, Ss)

    def get_theta2(self, i):
        S_diag = jnp.diag(self.Ss[i])
        theta1 = jnp.tensordot(S_diag, self.Bs[i], axes=[1, 0])
        theta2 = jnp.tensordot(theta1, self.Bs[i+1], axes=[2, 0])
        return theta2  # shape: (vL, i, j, vR)

    def bond_expectation_value(self, H_bonds):
        vals = []
        for i in range(self.L - 1):
            theta = self.get_theta2(i)  # (vL, i, j, vR)
            H = H_bonds[i]  # (i_out, j_out, i_in, j_in)
            tmp = jnp.tensordot(H, theta, axes=[[2,3],[1,2]])  # -> (i_out, j_out, vL, vR)
            tmp = jnp.tensordot(theta.conj(), tmp, axes=[[0,1,2,3],[2,0,1,3]])  # scalar
            vals.append(jnp.real(tmp))
        return jnp.sum(jnp.array(vals))


class TFIM:
    """
    Transverse Field Ising Model Hamiltonian: 
    H = -J sum sigma_x sigma_x - g sum sigma_z
    """
    def __init__(self, L, J, g):
        self.L = L
        self.J = J
        self.g = g
        self.d = 2 # local dimension (=2 for spin-1/2 of TFIM)
        self.sx = jnp.array([[0., 1.], [1., 0.]])
        self.sz = jnp.array([[1., 0.], [0., -1.]])
        self.id = jnp.eye(2)
        self.H_bonds = self.build_H_bonds()

    def build_H_bonds(self):
        sx, sz, I = self.sx, self.sz, self.id
        d = self.d
        H_list = []
        for i in range(self.L - 1):
            gL = self.g if i == 0 else 0.5 * self.g
            gR = self.g if i+1 == self.L-1 else 0.5 * self.g
            H = -self.J * jnp.kron(sx, sx) - gL * jnp.kron(sz, I) - gR * jnp.kron(I, sz)
            H = H.reshape(d, d, d, d)  # i_out, j_out, i_in, j_in
            H_list.append(H)
        return H_list

@partial(jit, static_argnums=(2,3))
def energy_expectation(vec, model, shapes_B, shapes_S):
    psi = MPS.unflatten(vec, shapes_B, shapes_S)
    return psi.bond_expectation_value(model.H_bonds)

# gradient of energy with respect to flattened parameters
# grad_E = grad(energy_expectation)


In [4]:
L = 10
J = 1.0
g = 1.5
model = TFIM(L, J, g)

psi = MPS.init_spinup(L)
shapes_B = [b.shape for b in psi.Bs]
shapes_S = [s.shape for s in psi.Ss]
vec = psi.flatten()

grad_fn = grad(lambda x: energy_expectation(x, model, shapes_B, shapes_S))

lr = 0.1
for step in range(50):
	g = grad_fn(vec)
	vec = vec - lr * g
	if step % 10 == 0:
		E = energy_expectation(vec, model, shapes_B, shapes_S)
		print(f"Step {step}: Energy = {E:.6f}")

psi_opt = MPS.unflatten(vec, shapes_B, shapes_S)

TypeError: Error interpreting argument to <function energy_expectation at 0x000001E3F4941DA0> as an abstract array. The problematic value is of type <class '__main__.TFIM'> and was passed to the function at path model.
This typically means that a jit-wrapped function was called with a non-array argument, and this argument was not marked as static using the static_argnums or static_argnames parameters of jax.jit.