# In [None]:
import time
import numpy as np
from scipy.stats import cauchy
from scipy.integrate import solve_ivp
import matplotlib.pyplot as plt
from tqdm import tqdm  # Import tqdm for progress bar
import pickle  # To store data

# Define the RHS function
def rhs(t, z, J, K, L, n, nu, omega, mu):
    """Evaluate the time derivative of the stacked state ``z = [x, theta, psi]``.

    With pairwise differences ``v_ij = v[i] - v[j]`` the derivatives are::

        dx_i/dt     = nu    - (J / n) * sum_{j != i} sin(x_ij)     * cos(theta_ij) * cos(psi_ij)
        dtheta_i/dt = omega - (K / n) * sum_{j != i} sin(theta_ij) * cos(psi_ij)   * cos(x_ij)
        dpsi_i/dt   = mu    - (L / n) * sum_{j != i} sin(psi_ij)   * cos(x_ij)     * cos(theta_ij)

    Args:
        t: Current time. Unused; present because ``solve_ivp`` calls the
            right-hand side as ``fun(t, y, *args)``.
        z: State vector of length ``3 * n`` laid out as ``[x, theta, psi]``.
        J: Coupling strength in the ``x`` equation.
        K: Coupling strength in the ``theta`` equation.
        L: Coupling strength in the ``psi`` equation.
        n: Number of oscillators, i.e. the length of each of ``x``,
            ``theta`` and ``psi``.
        nu: Constant term added to ``dx/dt`` (scalar or length-``n`` array).
        omega: Constant term added to ``dtheta/dt`` (scalar or length-``n`` array).
        mu: Constant term added to ``dpsi/dt`` (scalar or length-``n`` array).

    Returns:
        1-D array of length ``3 * n`` holding ``[dx/dt, dtheta/dt, dpsi/dt]``.
    """
    z = np.asarray(z)
    x = z[:n]
    theta = z[n:2 * n]
    psi = z[2 * n:3 * n]

    # Pairwise difference matrices: d[i, j] = v[i] - v[j].
    xd = x[:, np.newaxis] - x
    theta_d = theta[:, np.newaxis] - theta
    psi_d = psi[:, np.newaxis] - psi

    # Each cosine matrix appears in two of the three equations; evaluate once.
    cos_x = np.cos(xd)
    cos_theta = np.cos(theta_d)
    cos_psi = np.cos(psi_d)

    # nan_to_num maps NaN to 0 and +/-inf to large finite values, so a
    # non-finite state entry cannot poison the whole interaction matrix.
    x_rhs = -J * np.nan_to_num(np.sin(xd) * cos_theta * cos_psi)
    theta_rhs = -K * np.nan_to_num(np.sin(theta_d) * cos_psi * cos_x)
    psi_rhs = -L * np.nan_to_num(np.sin(psi_d) * cos_x * cos_theta)

    # Zero-diagonal mask that drops the j == i term from every row sum.
    # Built once per call instead of once per equation.
    off_diagonal = 1 - np.eye(xd.shape[0])
    inv_n = 1 / float(n)

    x_next = np.nan_to_num(nu + inv_n * np.sum(off_diagonal * x_rhs, axis=1))
    theta_next = np.nan_to_num(omega + inv_n * np.sum(off_diagonal * theta_rhs, axis=1))
    psi_next = np.nan_to_num(mu + inv_n * np.sum(off_diagonal * psi_rhs, axis=1))

    return np.concatenate((x_next, theta_next, psi_next))

# Unpack function
def unpack(sols, n):
    """Split a stacked solution array into its ``x``, ``theta`` and ``psi`` parts.

    Args:
        sols: 2-D array-like with one row per time point and at least
            ``3 * n`` columns laid out as ``[x | theta | psi]``.
        n: Number of columns belonging to each variable.

    Returns:
        A list ``[x, theta, psi]`` of float arrays, each of shape
        ``(len(sols), n)``. The arrays are copies and do not alias ``sols``.

    Raises:
        ValueError: If ``sols`` is non-empty and is not 2-D or has fewer
            than ``3 * n`` columns.
    """
    sols = np.asarray(sols, dtype=float)
    if len(sols) == 0:
        # No time points: return empty arrays of the expected width.
        return [np.zeros((0, n)) for _ in range(3)]
    if sols.ndim != 2 or sols.shape[1] < 3 * n:
        raise ValueError(
            f"sols must be 2-D with at least {3 * n} columns, got shape {sols.shape}"
        )
    # Column slicing replaces the per-row Python loop; .copy() keeps the
    # results independent of the input buffer.
    return [sols[:, start:start + n].copy() for start in (0, n, 2 * n)]

# Compute order parameters
def find_Ws(x, theta, psi):
    """Compute the four complex order parameters at every time point.

    For each row (time point) the mean over oscillators is taken of
    ``exp(1j * phase)`` with::

        W_1: phase = x + theta + psi
        W_2: phase = x + theta - psi
        W_3: phase = x - theta + psi
        W_4: phase = x - theta - psi

    Args:
        x: 2-D array of shape ``(numT, num_osc)``.
        theta: Array with the same shape as ``x``.
        psi: Array with the same shape as ``x``.

    Returns:
        Tuple ``(W_1, W_2, W_3, W_4)`` of complex arrays, each of length ``numT``.
    """
    x = np.asarray(x)
    theta = np.asarray(theta)
    psi = np.asarray(psi)
    numT, num_osc = x.shape

    def _order_parameter(phase):
        # Reduce over the oscillator axis for all time points in one call.
        return np.sum(np.exp(1j * phase), axis=1) / float(num_osc)

    W_1 = _order_parameter(x + theta + psi)
    W_2 = _order_parameter(x + theta - psi)
    W_3 = _order_parameter(x - theta + psi)
    W_4 = _order_parameter(x - theta - psi)
    return W_1, W_2, W_3, W_4

# Simulation parameters
dt = 0.1   # spacing of the sampled output times
T = 800    # end of the integration interval [0, T]
n = 200    # number of oscillators (length of each of x, theta, psi)

# Fixed seed so the random initial state is reproducible.
np.random.seed(0)
# Initial x, theta and psi, each drawn uniformly from [-pi, pi), in that order.
x0, theta0, psi0 = (np.random.uniform(-np.pi, np.pi, n) for _ in range(3))
z0 = np.concatenate([x0, theta0, psi0])

# Output sample times 0, dt, 2*dt, ... (T itself is excluded).
t = np.arange(0, T, dt)

# Constant terms of the three equations; all zero in this run.
# Alternative kept from the original: per-oscillator Cauchy-distributed values.
# nu, omega, mu = cauchy.rvs(size=n), cauchy.rvs(size=n), cauchy.rvs(size=n)
nu = omega = mu = 0

# Coupling sweep: 50 evenly spaced values in [-2.5, 2.5] for each of J, K and L.
J_values, K_values, L_values = (np.linspace(-2.5, 2.5, 50) for _ in range(3))
# One result cell per (J, K, L) combination.
heatmap = np.zeros((J_values.size, K_values.size, L_values.size))

# Sweep the full (J, K, L) grid: integrate the system once per combination and
# store the time-averaged sum of the four order-parameter magnitudes.
# The grid has len(J_values) * len(K_values) * len(L_values) combinations,
# each requiring a full ODE solve.
tic = time.perf_counter()
for i, J in enumerate(tqdm(J_values, desc="Computing Heatmap for J")):
    for j, K in enumerate(tqdm(K_values, desc=f"Processing K for J={J:.2f}", leave=False)):
        for k, L in enumerate(tqdm(L_values, desc=f"Processing L for J={J:.2f}, K={K:.2f}", leave=False)):
            # NOTE(review): sol.success is never inspected, so a failed
            # integration would still contribute a value — confirm this is acceptable.
            sol = solve_ivp(
                rhs,
                [0, T],
                z0,
                method='RK45',
                t_eval=t,
                args=(J, K, L, n, nu, omega, mu),
                rtol=1e-6,
                atol=1e-9,
            )
            # solve_ivp returns y as (state, time); transpose to (time, state).
            sols = sol.y.T
            x, theta, psi = unpack(sols, n)
            W_1, W_2, W_3, W_4 = find_Ws(x, theta, psi)
            # Average over all sampled times of |W_1| + |W_2| + |W_3| + |W_4|.
            heatmap[i, j, k] = np.mean(sum(np.abs(W) for W in (W_1, W_2, W_3, W_4)))

toc = time.perf_counter()
print(f"Heatmap computation took {toc - tic:0.4f} seconds")

# Store the three sweep axes together with the result grid in one pickle file.
with open('heatmap_data.pkl', 'wb') as f:
    pickle.dump(
        dict(J_values=J_values, K_values=K_values, L_values=L_values, heatmap=heatmap),
        f,
    )

# Build integer index grids over the (J, K, L) axes, then look up the
# coupling value at every grid point.
J_index, K_index, L_index = np.indices((len(J_values), len(K_values), len(L_values)))
J_values_mesh = np.take(J_values, J_index)
K_values_mesh = np.take(K_values, K_index)
L_values_mesh = np.take(L_values, L_index)

# One scatter point per grid cell: coordinates are (J, K, L), colour is the
# stored heatmap value. All four arrays are flattened in the same order.
x_vals, y_vals, z_vals = (
    mesh.flatten() for mesh in (J_values_mesh, K_values_mesh, L_values_mesh)
)
color_vals = heatmap.flatten()  # Sum of order parameters

# Single 3D axes filling a 10x8 inch figure.
fig, ax = plt.subplots(figsize=(10, 8), subplot_kw={'projection': '3d'})
scatter = ax.scatter(x_vals, y_vals, z_vals, c=color_vals, cmap='viridis', marker='o')
ax.set(
    xlabel='J',
    ylabel='K',
    zlabel='L',
    title='3D Heatmap of Sum of Order Parameters (J vs K vs L)',
)
fig.colorbar(scatter, label='Sum of Order Parameters')

plt.show()