# Hard spheres and hexagonal symmetry
> psi-6 / $\psi_6$ is a measure for hexagonal symmetry in 2d. In this notebook we set up different toy systems of hard spheres and perturb them, observing the decay of psi-6 from 1 to ~0.4. Some configurations are exported to the [xyz](https://en.wikipedia.org/wiki/XYZ_file_format) format to be visualized with [OVITO](https://www.ovito.org) (the basic version is enough).

In [None]:
# Enable IPython's autoreload extension so that edits to imported modules
# (e.g. hardspheres_2d) are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
from scipy.spatial.transform import Rotation
import numpy as np
import polars as pl
import seaborn as sns
import matplotlib.pyplot as plt
from dataclasses import dataclass
from pathlib import Path
import tqdm
from hardspheres_2d import (
    HardSpheres,
    generate_velocities,
    update_spheres,
    run_edmd,
    PerturbedSpheres,
)
from typing import Callable
from hardspheres_2d.psi6 import calc_psi6_bond_order_given_num_neighbors
from hardspheres_2d.hex import place_spheres_hexagonal, plot_spheres
from hardspheres_2d.dump import write_xyz_file
from hardspheres_2d.util import listify_array

## Single hexagonal ring

First, let's create a rotation matrix for a 60° rotation about the z axis, keeping only its 2d (x, y) block.

In [None]:
# 60-degree rotation about z, reduced to its in-plane (x, y) 2x2 block.
R = (
    Rotation.from_euler(seq="z", angles=60, degrees=True)
    .as_matrix()[:2, :2]
)
R

Now apply the rotation repeatedly to the `[0, 1]` vector to create a ring of six spheres forming a hexagon around a center sphere at the origin.

In [None]:
_x = np.array([0.0, 1.0])

# Row 0 is the center sphere at the origin; rows 1..6 are the ring spheres,
# i.e. _x rotated by 0, 1, ..., 5 successive 60-degree rotations.
_hexagon_rows = [[0.0, 0.0], _x]
_R_power = R
for _step in range(5):
    _hexagon_rows.append(_R_power @ _x)
    # Accumulate as (R @ R) @ R ... to mirror the left-associative chain
    # R @ R @ ... @ _x, so the floating-point results are identical.
    _R_power = _R_power @ R
x_perfect_hexagon = np.array(_hexagon_rows)
x_perfect_hexagon

In [None]:
def factory(v: list[list[float]]) -> Callable[[HardSpheres, np.ndarray], HardSpheres]:
    """Return a callable that rebuilds a `HardSpheres` with new positions.

    The returned callable takes a template `s` and positions `new_x` and builds
    a fresh `HardSpheres` whose t, sigma, a, dt_snapshot and t_snapshot are
    copied from `s`. Velocities are always the `v` captured here (the same
    object is passed every time, not copied). Masses are ones of length
    ``len(v)``. Neither `s.v` nor `s.m` is used.

    Args:
        v: Velocities used for every configuration built by the returned callable.
    """
    unit_masses = listify_array(np.ones(shape=len(v)))

    def spheres_with_new_x(s: HardSpheres, new_x: np.ndarray) -> HardSpheres:
        positions = listify_array(new_x)
        # Simulation metadata carried over unchanged from the template.
        carried_over = dict(
            t=s.t,
            sigma=s.sigma,
            a=s.a,
            dt_snapshot=s.dt_snapshot,
            t_snapshot=s.t_snapshot,
        )
        return HardSpheres(
            x=positions,  # type: ignore
            v=v,  # type: ignore
            m=unit_masses,  # type: ignore
            **carried_over,
        )

    return spheres_with_new_x

Let's instantiate our `HardSpheres` object

In [None]:
# Hexagon ring: positions from above, all spheres at rest, all with unit mass.
x = listify_array(x_perfect_hexagon)
v = listify_array(np.zeros_like(x_perfect_hexagon))
m = listify_array(np.ones(shape=len(x_perfect_hexagon)))
s = HardSpheres(
    t=0.0,
    t_snapshot=0.0,
    dt_snapshot=0.01,
    sigma=1.0,
    a=10,
    x=x,  # type: ignore
    v=v,  # type: ignore
    m=m,  # type: ignore
)

# Rebuilds spheres with new positions, reusing these (zero) velocities.
spheres_with_new_x = factory(v)  # type: ignore

Let's define a helper function for the psi-6 computation.

In [None]:
def get_psi6(s: HardSpheres) -> np.ndarray:
    """Return the psi-6 bond order of each sphere in `s` from its 6 nearest neighbors.

    Callers in this notebook take ``np.abs`` of the result, so it is presumably
    complex-valued (TODO confirm against hardspheres_2d.psi6).
    """
    neighbor_count = 6
    psi6 = calc_psi6_bond_order_given_num_neighbors(s, n_neighbors=neighbor_count)
    return psi6


get_psi6(s)

### Perturb center sphere only

In [None]:
def get_perturbed_spheres(
    s: HardSpheres, n_samples: int, ixs_perturb: list[int], stds: list[float]
) -> list[PerturbedSpheres]:
    """Sample Gaussian perturbations of selected sphere positions and record |psi6|.

    For every standard deviation in `stds`, `n_samples` independent samples are
    drawn. In each sample, every sphere listed in `ixs_perturb` is displaced from
    its position in `s` by a 2d Gaussian offset with that standard deviation.
    All other spheres keep their reference positions.

    Args:
        s: Reference configuration. It is not modified.
        n_samples: Number of samples drawn per standard deviation.
        ixs_perturb: Indices of the spheres to displace.
        stds: Standard deviations of the Gaussian displacements.

    Returns:
        One `PerturbedSpheres` per (std, sample), in loop order (std outer,
        sample inner), holding the perturbed configuration, the std used and
        ``np.abs`` of the per-sphere psi-6.
    """
    spheres_with_new_x = factory(s.v)  # type: ignore
    x_reference = np.array(s.x, dtype=float)
    perturbed_spheres = []
    for std in stds:
        for _ in range(n_samples):
            # Work on a fresh copy per sample. Modifying `s.x` in place would
            # let displacements accumulate across samples and stds and would
            # corrupt `s` for every later caller.
            x = x_reference.copy()
            for ix in ixs_perturb:
                dx = np.random.normal(scale=std, size=(2,))
                x[ix, :] += dx

            s_perturbed = spheres_with_new_x(s, x)

            psi6 = np.abs(get_psi6(s_perturbed))
            perturbed_spheres.append(
                PerturbedSpheres(spheres=s_perturbed, std=std, psi6=psi6)
            )
    return perturbed_spheres


n_samples = 10
ixs_perturb = [0]
stds = [0.001, 0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6]
perturbed_spheres = get_perturbed_spheres(s, n_samples, ixs_perturb, stds)

In [None]:
# Inspect the last sample, i.e. the final draw for the last std in `stds`.
perturbed_spheres[-1]

Let's have a look

In [None]:
# Overlay every sample: red for small perturbations (std < 0.2), faint blue otherwise.
fig, ax = plt.subplots()
for _s in perturbed_spheres:
    _is_small = _s.std < 0.2
    sns.scatterplot(
        x=_s.spheres.x[:, 0],
        y=_s.spheres.x[:, 1],
        ax=ax,
        color="red" if _is_small else "blue",
        alpha=0.4 if _is_small else 0.05,
    )

plt.tight_layout()

Neat: the center sphere scatters as expected, by a reasonable amount.

Let's organize the data into a data frame

In [None]:
def dataframify_perturbed_spheres(
    perturbed_spheres: list[PerturbedSpheres],
) -> pl.DataFrame:
    """Flatten samples into a long-format frame with one row per (sample, sphere).

    Columns:
        std: perturbation std stored on the sample.
        psi6: the sample's psi6 entry for that sphere.
        sphere: index of the sphere within the sample.
        iteration: position of the sample in `perturbed_spheres`.
    """
    rows = []
    for iteration, sample in enumerate(perturbed_spheres):
        for sphere_index, psi6_value in enumerate(sample.psi6):
            rows.append(
                {
                    "std": sample.std,
                    "psi6": psi6_value,
                    "sphere": sphere_index,
                    "iteration": iteration,
                }
            )
    return pl.DataFrame(rows)


df = dataframify_perturbed_spheres(perturbed_spheres)
df.head(5)

and inspect the psi-6 value of the center sphere as a function of its perturbation strength

In [None]:
def plot_psi6_vs_std(df: pl.DataFrame, sphere: int | None = None):
    """Plot mean psi6 (with a ±1 std band) against the perturbation std.

    Args:
        df: Long-format frame with columns "std", "psi6" and "sphere", as built
            by `dataframify_perturbed_spheres`.
        sphere: Index of the sphere to restrict the statistics to. ``None``
            aggregates over all spheres.
    """
    # Compare against None explicitly: sphere 0 (the center sphere) is falsy,
    # and a truthiness test would silently plot all spheres instead.
    if sphere is not None:
        df_stats = df.filter(pl.col("sphere").eq(sphere))
    else:
        df_stats = df

    df_stats = (
        df_stats.group_by("std")
        .agg(**{"psi6-mean": pl.col("psi6").mean(), "psi6-std": pl.col("psi6").std()})
        .with_columns(
            **{
                "lb": pl.col("psi6-mean") - pl.col("psi6-std"),
                "ub": pl.col("psi6-mean") + pl.col("psi6-std"),
            }
        )
        .sort("std")
    )

    _, ax = plt.subplots()
    ax.fill_between(data=df_stats, x="std", y1="lb", y2="ub", alpha=0.2)
    ax.plot(df_stats["std"], df_stats["psi6-mean"])
    ax.set_title(f"Sphere: {sphere}")
    ax.set_xlabel("std")
    ax.set_ylabel("|psi6|")
    plt.tight_layout()


plot_psi6_vs_std(df, sphere=0)

So in the ideal configuration we are at psi-6 = 1, and for stronger perturbations we land around 0.3.

### Perturb all spheres (center and ring)

In [None]:
n_samples = 10
ixs_perturb = list(range(7))  # the center sphere (0) and all six ring spheres
stds = [0.001, 0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6]
perturbed_spheres_hex_random = get_perturbed_spheres(
    s, n_samples=n_samples, ixs_perturb=ixs_perturb, stds=stds
)

In [None]:
# One row per (sample, sphere) for the all-sphere perturbation run.
df_all = dataframify_perturbed_spheres(perturbed_spheres=perturbed_spheres_hex_random)
df_all.head(5)

In [None]:
# Overlay every sample: red for small perturbations (std < 0.1), faint blue otherwise.
fig, ax = plt.subplots()
for _s in perturbed_spheres_hex_random:
    _is_small = _s.std < 0.1
    sns.scatterplot(
        x=_s.spheres.x[:, 0],
        y=_s.spheres.x[:, 1],
        ax=ax,
        color="red" if _is_small else "blue",
        alpha=0.4 if _is_small else 0.05,
    )

plt.tight_layout()

Isn't that pretty? :)

So how does the psi-6 for the center sphere behave when all spheres are perturbed?

In [None]:
# Request sphere 0 (the center sphere) from the all-sphere perturbation run.
plot_psi6_vs_std(df=df_all, sphere=0)

It seems quite similar, though with a faster drop from 1 to 0.3.

What about the non-center spheres, e.g. sphere 1?

In [None]:
# Request sphere 1, one of the ring spheres.
plot_psi6_vs_std(df=df_all, sphere=1)

Not surprisingly, it does not start at psi-6 = 1 and stays around 0.3: its 6 nearest neighbors do not surround it but lie off to one side. Remember, the system is one sphere surrounded by 6 spheres, and sphere 1 is one of the spheres in the ring.

## Perfect hexagon

### Random perturbations

In [None]:
a = 10.0  # Edge length of the square
r = 0.5  # Radius of the spheres

sphere_centers = place_spheres_hexagonal(a, r, gap=0.1)

if sphere_centers.size == 0:
    print("No spheres could be placed within the given dimensions.")
else:
    plot_spheres(a, r, sphere_centers)

Isn't that pretty? I may have stared at this for a bit :-P

Let's set up our `HardSpheres` object

In [None]:
# Lattice configuration: every sphere at its lattice site, at rest, with unit mass.
s_perfect = HardSpheres(
    t=0.0,
    t_snapshot=0.0,
    dt_snapshot=0.01,
    a=a,
    sigma=r,
    x=listify_array(sphere_centers),  # type: ignore
    v=listify_array(np.zeros_like(sphere_centers)),  # type: ignore
    m=listify_array(np.ones(shape=len(sphere_centers))),  # type: ignore
)

In [None]:
n_samples = 100
ixs_perturb = list(range(len(sphere_centers)))  # every sphere in the lattice
stds = [0.001, 0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6]
perturbed_spheres_hex_random = get_perturbed_spheres(
    s_perfect, n_samples=n_samples, ixs_perturb=ixs_perturb, stds=stds
)

In [None]:
df_hex_random = dataframify_perturbed_spheres(
    perturbed_spheres=perturbed_spheres_hex_random
)

In [None]:
# Aggregate over all spheres of the lattice.
plot_psi6_vs_std(df_hex_random, sphere=None)

Much smoother than the single hexagon ring example, but the same behavior.

Optionally write to disk for inspection in OVITO

In [None]:
# Optional export for inspection in OVITO. Set to True to write the xyz file,
# with each sphere's stored |psi6| attached as the "psi6_n_abs" feature.
WRITE_SAMPLING_XYZ = False

if WRITE_SAMPLING_XYZ:
    sampled_spheres = [p.spheres for p in perturbed_spheres_hex_random]
    psi6_abs_n = [p.psi6 for p in perturbed_spheres_hex_random]

    file = Path("hardspheres-hexagon-sampling.xyz")

    write_xyz_file(
        sampled_spheres,
        file,
        sphere_features={
            "psi6_n_abs": psi6_abs_n,
        },
    )

### Running EDMD

Let's run event-driven molecular dynamics (EDMD). We do the same as above, but in a more physically motivated way (mumbling something about temperature).

In [None]:
a = 10.0  # Edge length of the square
r = 0.5  # Radius of the spheres
a_ext = 1.1 * a  # 10% larger edge, used for the plot and the EDMD system below

sphere_centers = place_spheres_hexagonal(a, r, gap=0.1)
plot_spheres(a_ext, r, sphere_centers, save=False)

In [None]:
m = np.ones(len(sphere_centers))  # unit mass for every sphere

In [None]:
T = 0.01  # passed to generate_velocities here and to run_edmd below
v = generate_velocities(m, dim=2, T=T)

# Histogram of all velocity components pooled together.
sns.histplot(x=v.reshape(-1))

In [None]:
# Snapshot interval, passed to HardSpheres as `dt_snapshot` below.
dt_snapshot = 0.1

In [None]:
# EDMD starting configuration: lattice positions with thermal velocities,
# placed in the enlarged square of edge a_ext.
s_edmd = HardSpheres(
    t=0.0,
    t_snapshot=0.0,
    dt_snapshot=dt_snapshot,
    a=a_ext,
    sigma=r,
    m=list(m),
    x=listify_array(sphere_centers),  # type: ignore
    v=listify_array(v),  # type: ignore
)

Now let's run EDMD with `run_edmd`

In [None]:
n_iter = 1000

# With return_extra_info=True, run_edmd returns three values, unpacked below.
_edmd_outputs = run_edmd(
    s_edmd, T=T, n_iter=n_iter, return_extra_info=True, progress=True
)
history_spheres, history_psi6_abs_n, history_perturbed_spheres = _edmd_outputs

Let's wrangle the data again

In [None]:
df_edmd = dataframify_perturbed_spheres(perturbed_spheres=history_perturbed_spheres)

And visualize psi-6 statistics

In [None]:
def plot_psi6_vs_time(df: pl.DataFrame, save: bool = False):
    """Plot per-iteration psi6 statistics (mean ± std band, max, min).

    Args:
        df: Long-format frame with columns "iteration", "psi6" and "std". The
            "std" column must contain exactly one unique value, which is shown
            as T in the title (`.item()` raises otherwise).
        save: If True, also write the figure to
            "hardspheres-hexagon-edmd-psi6-over-time.png".
    """
    psi6 = pl.col("psi6")
    mean, spread = pl.col("psi6-mean"), pl.col("psi6-std")
    df_stats = (
        df.group_by("iteration")
        .agg(
            **{
                "psi6-mean": psi6.mean(),
                "psi6-std": psi6.std(),
                "psi6-max": psi6.max(),
                "psi6-min": psi6.min(),
            }
        )
        .with_columns(**{"lb": mean - spread, "ub": mean + spread})
        .sort("iteration")
    )
    iterations = df_stats["iteration"]

    with sns.axes_style("whitegrid"):
        _, ax = plt.subplots()
        ax.fill_between(
            data=df_stats,
            x="iteration",
            y1="lb",
            y2="ub",
            alpha=0.2,
            color="blue",
            label="mean+-std",
        )
        ax.plot(iterations, df_stats["psi6-mean"], color="blue", label="mean")
        # Extremes per iteration: dashed max, dotted min.
        for column, linestyle, label in (
            ("psi6-max", "--", "max"),
            ("psi6-min", "dotted", "min"),
        ):
            ax.plot(
                iterations,
                df_stats[column],
                color="black",
                linestyle=linestyle,
                label=label,
            )
        T = df["std"].unique().item()
        ax.set_title(f"Absolute psi6 value for 6 nearest neighbors @ {T=} over time")
        ax.set_ylim(0, 1)
        ax.set_xlabel("Iteration")
        ax.set_ylabel("|psi6|")
        ax.legend(loc="upper right")
        plt.tight_layout()
        if save:
            plt.savefig("hardspheres-hexagon-edmd-psi6-over-time.png")


plot_psi6_vs_time(df_edmd, save=False)

So the behavior is similar to that with the random perturbations.

In [None]:
# Optional export of the EDMD trajectory for inspection in OVITO. Set to True
# to write the xyz file, with |psi6| per snapshot as the "psi6_n_abs" feature.
WRITE_EDMD_XYZ = False

if WRITE_EDMD_XYZ:
    file = Path("hardspheres-hexagon-edmd.xyz")

    write_xyz_file(
        history_spheres,
        file,
        sphere_features={
            "psi6_n_abs": history_psi6_abs_n,
        },
    )