In [None]:
# Switch the kernel's working directory to /constellaration
# (presumably the repository checkout root -- TODO confirm for your environment).
%cd /constellaration
# Enable the autoreload extension in mode 2, so imported modules are re-imported
# before each cell executes and edits to the package are picked up without a restart.
%load_ext autoreload
%autoreload 2

In [None]:

import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1.inset_locator import inset_axes, mark_inset
import numpy as np
from constellaration import forward_model, problems

from sklearn import decomposition, mixture, calibration

from constellaration.generative_model import bootstrap_dataset

from constellaration.generative_model import optimize_with_mcmc
import seaborn as sns


This notebook demonstrates the generative-model approach for the simple-to-build QI stellarator problem.

### 1. Load data from HuggingFace and do PCA

In [None]:

# --- Configuration ---------------------------------------------------------
# Not referenced by the cells visible in this notebook; kept so that any
# external use of the name keeps working.
DATASET_MAX_TOROIDAL_MODE = 4
# Fourier resolution handed to `_to_X` / `_x_to_surface` in later cells.
MAX_POLOIDAL_MODE = 4
MAX_TOROIDAL_MODE = 4
# Only boundaries with this number of field periods are kept.
N_FIELD_PERIODS = 3
# Random seed reused by later cells.
SEED = 24

# --- Load and prepare the dataset -------------------------------------------
dframe = bootstrap_dataset.load_source_datasets_with_no_errors()
dframe = dframe[dframe["boundary.n_field_periods"] == N_FIELD_PERIODS]
# dframe = bootstrap_dataset._unflatten_metrics_and_concatenate(dframe)
dframe = bootstrap_dataset._unserialize_surface(dframe)
dframe = bootstrap_dataset._augment_dataset(dframe)
print(f"Dataset size: {dframe.shape[0]}")

problem = problems.SimpleToBuildQIStellarator()
# Remove the "metrics." substring from every column name except "metrics.id".
dframe.columns = [
    c if c == "metrics.id" else c.replace("metrics.", "") for c in dframe.columns
]
# Show a preview only; a bare expression is rendered only when it is the last
# statement of the cell, and rendering the full frame adds nothing for the reader.
dframe.head()



In [None]:

# Start again from a fresh problem definition and loosen its constraint bounds:
# the lower bound is reduced and each upper bound is raised by
# `relaxation_factor` times the magnitude of its base value.
problem = problems.SimpleToBuildQIStellarator()
relaxation_factor = 0.33
# The aspect-ratio bound is deliberately left untouched:
# problem._aspect_ratio_upper_bound = 10.0 + 10.0 * relaxation_factor
problem._edge_rotational_transform_over_n_field_periods_lower_bound = (
    0.25 - 0.25 * relaxation_factor
)
problem._log10_qi_upper_bound = -4.0 + 4.0 * relaxation_factor
problem._edge_magnetic_mirror_ratio_upper_bound = 0.2 + 0.2 * relaxation_factor
problem._max_elongation_upper_bound = 5.0 + 5.0 * relaxation_factor

# Feature matrix built from the boundaries, truncated to the configured modes.
X = bootstrap_dataset._to_X(
    dframe=dframe,
    max_poloidal_mode=MAX_POLOIDAL_MODE,
    max_toroidal_mode=MAX_TOROIDAL_MODE,
)
# Constraint and objective values evaluated against the relaxed problem.
Y_constraints = bootstrap_dataset._to_Y_constraints(dframe=dframe, problem=problem)
Y_objective = bootstrap_dataset._to_Y_objective(dframe=dframe, problem=problem)

# A row counts as feasible when every one of its constraint values is <= 0.
n_feasible_candidates_in_the_data = np.all(Y_constraints <= 0, axis=1).sum()
print(f"Feasible candidates in the data: {n_feasible_candidates_in_the_data}")

# PCA reduction: keep enough whitened components to explain 99.98% of the
# variance. `Z` is reused by the MCMC and plotting cells further down.
pca = decomposition.PCA(n_components=0.9998, whiten=True)
Z = pca.fit_transform(X)
print(f"Reduced to {Z.shape[1]} dimensions.")


### 2. Train GMM and random forest classifiers

In [None]:

def get_classifier_regressor_GMM(
    X: np.ndarray,
    Y_cons: np.ndarray,
    pca_explained_variance: float,
    n_estimators: int,
    seed: int,
    feasibility_threshold: float = 0.8,
    gmm_n_components: int = 24,
    show_calibration: bool = True,
) -> tuple:
    """Fit a PCA projection, one classifier per constraint, and a GMM.

    The GMM is fitted only on the samples that every constraint classifier
    predicts to be feasible with probability >= ``feasibility_threshold``.

    Args:
        X: Feature matrix with one row per candidate.
        Y_cons: Constraint values with one row per candidate and one column
            per constraint; a value <= 0 means the constraint is satisfied.
        pca_explained_variance: Passed to ``PCA(n_components=...)``.
        n_estimators: Forwarded to
            ``bootstrap_dataset._fit_calibrated_classifier``.
        seed: Random state for the classifiers and the GMM.
        feasibility_threshold: Minimum predicted probability of feasibility
            (for every constraint) a sample needs to be used for the GMM fit.
        gmm_n_components: Number of mixture components of the GMM.
        show_calibration: When True, show a calibration plot for each
            constraint classifier (evaluated on the training data).

    Returns:
        A tuple ``(gmm, pca, constraint_classifiers)`` holding the fitted
        ``GaussianMixture``, the fitted ``PCA`` and the list of fitted
        classifiers (one per column of ``Y_cons``).

    Raises:
        ValueError: If fewer samples pass the feasibility threshold than
            there are GMM components.
    """
    n_feasible_candidates_in_the_data = np.sum(np.all(Y_cons <= 0, axis=1))
    print(f"Feasible candidates in the data: {n_feasible_candidates_in_the_data}")

    # 1. PCA reduction
    pca = decomposition.PCA(n_components=pca_explained_variance, whiten=True)
    Z = pca.fit_transform(X)
    print(f"Reduced to {Z.shape[1]} dimensions.")

    # 2. Train one binary classifier per constraint (1 = constraint satisfied)
    constraint_classifiers: list[calibration.CalibratedClassifierCV] = []
    for j in range(Y_cons.shape[1]):
        y_bin = (Y_cons[:, j] <= 0).astype(int)
        print(f"Constraint {j}: {np.sum(y_bin)} feasible candidates")
        classifier = bootstrap_dataset._fit_calibrated_classifier(
            X=Z,
            y=y_bin,
            random_state=seed,
            n_estimators=n_estimators,
        )
        constraint_classifiers.append(classifier)
        if show_calibration:
            calibration.CalibrationDisplay.from_estimator(
                estimator=classifier,
                X=Z,
                y=y_bin,
                n_bins=10,
            )
            plt.show()

    # Row j holds the predicted probability that constraint j is satisfied.
    probabilities = np.vstack(
        [clf.predict_proba(X=Z)[:, 1] for clf in constraint_classifiers]
    )
    print(f"Probabilities: {probabilities}")
    is_feasible = np.all(probabilities >= feasibility_threshold, axis=0)
    n_selected = int(np.count_nonzero(is_feasible))
    print(f"Number of samples used to fit GMM: {n_selected}")
    if n_selected < gmm_n_components:
        # Fail early with an actionable message instead of a generic fit error.
        raise ValueError(
            f"Only {n_selected} samples pass the feasibility threshold "
            f"{feasibility_threshold}, but the GMM needs at least "
            f"{gmm_n_components}; lower the threshold or the component count."
        )

    # 3. Fit the GMM in PCA space on the likely-feasible samples only
    # gmm_n_components = bootstrap_dataset._n_components_that_minimizes_bic(Z, seed=seed)
    gmm = mixture.GaussianMixture(n_components=gmm_n_components, random_state=seed)
    gmm.fit(Z[is_feasible])
    print(f"Fitted GMM with {gmm_n_components} components.")

    return gmm, pca, constraint_classifiers


In [None]:
# Fit the PCA projection, the per-constraint classifiers and the GMM.
# Note that this rebinds `pca` to the projection fitted inside the function.
gmm, pca, constraint_classifiers = get_classifier_regressor_GMM(
    X=X,
    Y_cons=Y_constraints,
    pca_explained_variance=0.9998,
    n_estimators=200,
    seed=SEED,
)


### 3. MCMC
The GMM serves as the prior, and a "quasi-likelihood" built from the constraint classifiers rewards samples that lie in the feasible domain; MCMC is then used to sample from the resulting posterior.

In [None]:
def log_prior(x):
    """Return the GMM log-density at ``x`` (points in PCA space).

    A single 1-D point is promoted to a one-row 2-D array first, because
    ``gmm.score_samples`` is called with a 2-D input. The result has one
    entry per row.
    """
    points = x[np.newaxis, :] if x.ndim == 1 else x
    return gmm.score_samples(points)
def quasi_log_likelihood(x):
    """Return the summed log of the constraint classifiers' predictions at ``x``.

    For every classifier in ``constraint_classifiers`` the class predicted by
    ``clf.predict`` is shifted by 1e-10 (to keep the log finite) and logged;
    the per-constraint terms are then added up, giving one value per row.

    NOTE(review): this uses the hard prediction of ``clf.predict`` (presumably
    0/1 labels), not ``predict_proba`` -- confirm that this is intended.
    """
    if x.ndim == 1:
        x = x.reshape(1, -1)
    per_constraint = np.array(
        [np.log(clf.predict(X=x) + 1e-10) for clf in constraint_classifiers]
    )
    # Sum (not product) because the terms are already in log space.
    return np.sum(per_constraint, axis=0)
def forward_(x):  # noqa
    """Placeholder forward function: ignores ``x`` and returns ``(None, None)``."""
    return (None, None)
# Trace of the sampler, filled in by `callback_print`.
history = dict(x=[], logp=[])
def callback_print(state):
    """Record the current sample and log-probability; report every 100th iteration.

    ``state`` is indexed with the keys ``'current_sample'``, ``'logp'`` and
    ``'iteration'``. Other sinks (pickle files, wandb, ...) could be added here.
    """
    history["x"].append(state["current_sample"])
    history["logp"].append(state["logp"])
    iteration = state["iteration"]
    if iteration % 100 != 0:
        return
    print(f"Iteration: {iteration}, logp: {state['logp']}")
# Predicted probability, per constraint (rows) and per dataset sample (columns),
# that the constraint is satisfied, evaluated in PCA space.
probabilities = np.vstack(
    [clf.predict_proba(X=Z)[:, 1] for clf in constraint_classifiers]
)
# Keep only samples that every classifier rates as feasible with probability
# >= 0.995; these serve as candidate starting points for the MCMC chain.
is_feasible = np.all(probabilities >= 0.995, axis=0)
# Report the count explicitly: a bare expression in the middle of a cell is
# evaluated but never displayed.
print(f"High-confidence feasible candidates: {int(np.sum(is_feasible))}")

# Run the optimization with mcmc
settings = optimize_with_mcmc.OptimizeWithMcmcSettings(
    num_samples=4000,
)


In [None]:

# Seed NumPy's global RNG so the perturbed starting point is reproducible.
# NOTE(review): whether `optimize_with_mcmc` also draws from this global RNG is
# not visible here -- confirm if the chain itself must be reproducible.
np.random.seed(SEED)

# Start from an empty trace so that re-running this cell does not append to
# the trace recorded by a previous run.
history["x"].clear()
history["logp"].clear()

if not np.any(is_feasible):
    raise ValueError(
        "No dataset sample passes the feasibility threshold; "
        "cannot choose an initial guess for MCMC."
    )

# Take the first high-confidence feasible sample and perturb every coordinate
# with Gaussian noise scaled to 10% of that coordinate's value.
initial_guess = Z[is_feasible][0, :].reshape(-1)
initial_guess = initial_guess + initial_guess * 0.1 * np.random.randn(*initial_guess.shape)  # noqa
mcmc_samples = optimize_with_mcmc.optimize_with_mcmc(
    function=forward_,
    x0=initial_guess,
    callback=callback_print,
    settings=settings,
    prior=log_prior,
    likelihood=quasi_log_likelihood,
)

# Plot the recorded log-probabilities, skipping the first 1000 entries
# (presumably treated as burn-in).
plt.plot(history["logp"][1000:])
plt.xlabel("Iteration")
plt.ylabel("Log probability")
plt.grid()
plt.tight_layout()
plt.show()



In [None]:

# Randomly pick 50 distinct points from the last 4000 MCMC samples, using a
# seeded generator so the selection is reproducible for a given chain.
rng = np.random.default_rng(SEED)
tail_samples = mcmc_samples[-4000:]
feasible_points = tail_samples[rng.choice(tail_samples.shape[0], size=50, replace=False)]

# Map the points back from PCA space and run the forward model on the last 40
# of them; the plotting cell relies on the same `[-40:]` slice.
metrics_list = []
x_feasible = pca.inverse_transform(feasible_points)
for x_hat in x_feasible[-40:, :]:
    surface = bootstrap_dataset._x_to_surface(
        x_hat,
        max_poloidal_mode=MAX_POLOIDAL_MODE,
        max_toroidal_mode=MAX_TOROIDAL_MODE,
        n_field_periods=N_FIELD_PERIODS,
    )
    try:
        metrics = forward_model.forward_model(
            boundary=surface,
            settings=forward_model.ConstellarationSettings(),
        )[0]
    except Exception as exc:
        # Best effort: keep going, but say why the run failed and store a NaN
        # placeholder so `metrics_list` stays aligned with the evaluated points.
        print(f"Error in forward model: {exc}")
        metrics_list.append(np.nan)
        continue
    # Outside the `try`, so a failure here cannot append a second entry.
    metrics_list.append(metrics)
    print(f"Max elongation: {metrics.max_elongation}")

for metrics in metrics_list:
    if isinstance(metrics, float):
        # NaN placeholder of a failed forward-model run: nothing to fix up.
        continue
    # Use the absolute value of the rotational transform.
    metrics.edge_rotational_transform_over_n_field_periods = np.abs(metrics.edge_rotational_transform_over_n_field_periods)  # noqa: E501


### 4. Compare with VMEC++ and plot results
The figure may not match the one in the paper exactly, because the random seed may differ.

In [None]:
# Dataset rows whose constraint values are all <= 0, in PCA coordinates.
idx_feasible = np.flatnonzero(np.all(Y_constraints <= 0, axis=1))
Z_feasible = Z[idx_feasible]

# An MCMC sample is "green" when every constraint classifier both predicts
# class 1 and assigns that class a probability above 0.95.
c = np.logical_and.reduce(
    [
        (clf.predict(X=mcmc_samples) == 1)
        & (clf.predict_proba(X=mcmc_samples)[:, 1] > 0.95)
        for clf in constraint_classifiers
    ],
    axis=0,
)

green_points = mcmc_samples[c]

def _is_vmec_feasible(metrics):
    """Return the problem's feasibility verdict; NaN placeholders count as infeasible."""
    if isinstance(metrics, float) and np.isnan(metrics):
        # Placeholder stored when the forward model failed for this point.
        return False
    return problem.is_feasible(metrics)


def _draw_layers(target_ax, mcmc_size, vmec_size, dataset_size, with_labels):
    """Draw the four plot layers (KDE, MCMC, VMEC, dataset) on ``target_ax``.

    The marker sizes are passed in because the main axes and the inset use
    different values; legend labels are attached only when ``with_labels``.
    """
    # Density of the MCMC samples accepted by all classifiers.
    sns.kdeplot(
        x=green_points[:, 0],
        y=green_points[:, 1],
        fill=True,
        cmap="Greens",
        ax=target_ax,
        alpha=0.6,
    )
    # All MCMC samples, coloured by classifier acceptance.
    target_ax.scatter(
        mcmc_samples[:, 0],
        mcmc_samples[:, 1],
        c=np.where(c, "green", "red"),
        s=mcmc_size,
        label="MCMC Samples" if with_labels else None,
    )
    # Points checked with the forward model, coloured by actual feasibility.
    target_ax.scatter(
        feasible_points[-40:, 0],
        feasible_points[-40:, 1],
        c=np.where(vmec_feasible, "green", "red"),
        marker="x",
        s=vmec_size,
        label="VMEC" if with_labels else None,
        linewidths=0.7,
    )
    # Feasible points that were already in the dataset.
    target_ax.scatter(
        Z_feasible[:, 0],
        Z_feasible[:, 1],
        c="blue",
        marker="+",
        s=dataset_size,
        label="Feasible Points in Dataset" if with_labels else None,
        linewidths=0.7,
    )


# Feasibility of each forward-model result; failed runs count as infeasible.
vmec_feasible = np.array([_is_vmec_feasible(metric) for metric in metrics_list])

# Main plot
fig, ax = plt.subplots()
_draw_layers(ax, mcmc_size=1, vmec_size=60, dataset_size=30, with_labels=True)
ax.set_xlabel("PC1")
ax.set_ylabel("PC2")
# Axis limits for the main plot; adjust them to the data at hand.
ax.set_xlim(0.6, 1.2)
ax.set_ylim(0.5, 2.0)

# Inset showing a zoomed-in region with slightly larger markers
inset_ax = inset_axes(ax, width="55%", height="55%", loc="upper right")
_draw_layers(
    inset_ax, mcmc_size=1.2, vmec_size=80, dataset_size=40, with_labels=False
)
# Axis limits for the inset; adjust them to the data at hand.
inset_ax.set_xlim(0.7, 0.8)
inset_ax.set_ylim(1.1, 1.3)
# No ticks on the inset
inset_ax.set_xticks([])
inset_ax.set_yticks([])

# Zoom lines connecting the inset to the region it shows in the main plot
mark_inset(ax, inset_ax, loc1=2, loc2=4, fc="none", ec="0.5", lw=1)

ax.legend(
    loc="lower right",  # Legend in the bottom-right corner
    frameon=True,       # Add a bounding box
    fontsize=20,        # Adjust font size
    fancybox=True,      # Rounded corners for the box
    shadow=True         # Add shadow to the box
)
plt.tight_layout()

plt.show()
