In [51]:
import numpy as np
rng = np.random.default_rng()

In [54]:
def heatbath_update(U,beta):
    site = tuple(rng.integers(0,len(U),4))
    theta = sample_angle(beta)
    U[site] = np.exp(1j*theta)

def sample_angle(beta):
    alpha = np.sqrt(2*beta)*beta

    while True:
        Z = rng.uniform(0,1)
        x = -1 + np.log(1 + np.exp(2*alpha - 1)*Z)

        Q = np.exp(alpha*(np.cos(np.pi/2*(1-x))-x))
        Q_max = np.exp(0.2105137*alpha)

        Z_prime = rng.uniform(0,1)
        if Q/Q_max > Z_prime:
            angle = np.pi*(1-x)/2
            return angle

def run_heatbath(U, alpha, n, loop_sites_list):
    wilson_loop_sum = 0.0
    total_plaquettes = len(loop_sites_list)

    for _ in range(n):
        heatbath_update(U, alpha)

        # Sum Wilson loops for all plaquettes
        loop_value_sum = 0.0
        for loop_sites in loop_sites_list:
            loop_value_sum += wilson_loop(U, loop_sites)

        # Add the sum of loop values for this update
        wilson_loop_sum += loop_value_sum

    # Average Wilson loop value over all updates and all plaquettes
    return wilson_loop_sum / (n * total_plaquettes)

def create_plaquettes(width):
    loop_sites_list = []
    for x in range(width-1):   # Loop over the entire grid (excluding edges for plaquettes)
        for y in range(width-1):
            for z in range(width-1):
                for t in range(width-1):
                    loop_sites_list.append(create_plaquette(x, y, z, t))
    return loop_sites_list

def wilson_loop(U, loop_sites):
    loop_product = 1.0 + 0.0j
    for (x, y, z, t, mu) in loop_sites:
        loop_product *= U[(x, y, z, t)]
    return np.real(loop_product)

In [67]:
# Lattice size
width = 4
U = np.exp(2j * np.pi * np.random.rand(width, width, width, width))
alpha = 1.0
n_updates = 1000

# Create plaquettes for the entire grid
loop_sites_list = create_plaquettes(width)

# Run the heatbath update and compute the average Wilson loop over the whole grid
average_wilson_loop = run_heatbath(U, alpha, n_updates, loop_sites_list)
print(f"Average Wilson Loop for the entire heatbath: {average_wilson_loop}")

Average Wilson Loop for the entire heatbath: -0.08524628805463488
