In [1]:
import torch
import matplotlib.pyplot as plt
import numpy as np
import os


# Global plotting configuration.
# Route all figure text through LaTeX (requires a working LaTeX installation).
plt.rcParams['text.usetex'] = True

# Optional settings, currently disabled:
# plt.rc('font', family='serif')
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Name of the matplotlib colormap intended for the figures in this notebook.
cmap = 'viridis'

In [None]:
# Sweep configuration. Every value below is encoded in the name of the results
# folder, so together they select which saved run is loaded.
model_name = 'scan'
N = 100
c = 100
mu = 0.0
num_delta = 25
num_gamma = 25
num_trials = 10

# Folder layout:
#   <model>_N_<N>_c_<c>_mu_<mu>_num_delta_<n>_num_gamma_<n>_num_trials_<n>
folder_name = (
    f'{model_name}_N_{N}_c_{c}_mu_{mu}'
    f'_num_delta_{num_delta}_num_gamma_{num_gamma}_num_trials_{num_trials}'
)

# Results live under ../data/<folder_name>, relative to the notebook's working directory.
path = os.path.join('..', 'data', folder_name)

In [None]:
# Evenly spaced grids for the two swept parameters.
# The gamma grid starts at 0.01 rather than 0; the delta grid covers [0, 1].
gamma_range = np.linspace(0.01, 1, num=num_gamma)
delta_range = np.linspace(0, 1, num=num_delta)

In [None]:
# Load the stability flags saved for the sweep selected by `path`.
# map_location='cpu' lets the file load on machines that lack the device the
# tensor was saved from, and guarantees a CPU tensor that NumPy/matplotlib can
# consume directly.
# presumably shaped (num_delta, num_gamma, num_trials) -- TODO confirm against
# the script that wrote the file.
# NOTE: torch.load unpickles its input; only point it at trusted files.
bool_stable = torch.load(os.path.join(path, 'bool_stable.pt'), map_location='cpu')

# Additional reference data. In the visible part of this notebook these arrays
# are only referenced by the commented-out scatter overlays in the plotting cell.
data_sim = np.loadtxt('../data/balanced_exp_N10_n100.txt')
data_a = np.loadtxt('../data/function_contours.txt')
data_y = np.loadtxt('../data/y_contours.txt')

In [None]:
# Plot the stability phase diagram: percentage of stable circuits per grid cell.

# Average the stability flags over the last axis (dim=2) and scale to percent.
# presumably dim 0 indexes delta and dim 1 indexes gamma, since rows are drawn
# along y and columns along x -- TODO confirm against the sweep script.
percent_stable = bool_stable.float().mean(dim=2) * 100

# Hand matplotlib a plain CPU NumPy array so the plot does not depend on the
# tensor's device or autograd state.
percent_stable_np = percent_stable.detach().cpu().numpy()

# Heatmap on an explicit figure/axes pair.
fig, ax = plt.subplots(figsize=(12, 10))
# NOTE(review): the extent spans grid-min to grid-max, so cell *edges* (not cell
# centres) sit on the first/last grid values; centres are offset by half a cell.
heatmap = ax.imshow(
    percent_stable_np,
    extent=[gamma_range.min(), gamma_range.max(), delta_range.min(), delta_range.max()],
    origin='lower',
    aspect='auto',
    cmap=cmap,  # notebook-wide colormap from the configuration cell ('viridis')
)

# Colorbar: the label is set once, with its final font size.
colorbar = fig.colorbar(heatmap, ax=ax, fraction=0.046, pad=0.04)
colorbar.set_label("Percent Stable Circuits", fontsize=20)
colorbar.ax.tick_params(labelsize=14)

# Optional overlays (curves from Flaviano); uncomment to draw them.
# ax.scatter(data_sim[:, 1], data_sim[:, 0], c='black', edgecolor='black', s=50, label="Simulation")
# ax.scatter(data_y[:, 0], data_y[:, 1], c='cyan', edgecolor='black', s=50, label="y-theory")
# ax.scatter(data_a[:, 0], data_a[:, 1], c='red', edgecolor='black', s=50, label="a-theory")

# Only draw a legend when at least one labelled artist exists; with the overlays
# disabled an unconditional legend() call warns and adds an empty legend box.
legend_handles, legend_labels = ax.get_legend_handles_labels()
if legend_handles:
    ax.legend(legend_handles, legend_labels, fontsize=20, loc='lower right')

# Axis labels and title.
# NOTE(review): the x axis spans gamma_range but is labelled 'Contrast' -- confirm
# that gamma is the contrast parameter.
ax.set_xlabel(r'Contrast', fontsize=22)
ax.set_ylabel(r'$\Delta$', fontsize=22)
ax.set_title("Phase Diagram: Percent Stable Circuits", fontsize=20)

# Larger tick labels on both axes.
ax.tick_params(axis='both', labelsize=20)

# Save the figure as SVG next to the data it was produced from, then display it.
save_fig_path = os.path.join(path, 'percent_stable_circuits.svg')
fig.savefig(save_fig_path)
plt.show()