In [1]:
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
from dataclasses import dataclass
from sklearn.metrics import silhouette_score
from sklearn.preprocessing import StandardScaler

# ---------- Data container ----------
@dataclass
class ClusterData:
    """Container for two labelled groups of points, ``red`` and ``blue``.

    ``X`` stacks the two groups row-wise (red rows first) and ``labels``
    returns the matching class labels: 0.0 for every red row and 1.0 for
    every blue row. Both are recomputed on each access.
    """

    # First group (class label 0). Presumably (n, 2) coordinates, as produced
    # by the generators in this module -- not enforced here.
    red: np.ndarray
    # Second group (class label 1); same assumptions as ``red``.
    blue: np.ndarray

    @property
    def X(self):
        """Return ``red`` and ``blue`` stacked vertically, red rows first."""
        return np.vstack([self.red, self.blue])

    @property
    def labels(self):
        """Return float labels aligned with ``X``: 0.0 for red, 1.0 for blue."""
        return np.repeat([0.0, 1.0], [len(self.red), len(self.blue)])

# ---------- Synthetic generators ----------
def generate_mtms_data(n_points=400, seed=42) -> ClusterData:
    """Generate two overlapping groups of Gaussian blobs ("MTMS" scenario).

    Each group is a mixture of four isotropic Gaussian blobs (standard
    deviation 5) centred on fixed 2-D locations. The red and blue centres are
    interleaved, so the two groups overlap.

    Args:
        n_points: Number of points per group. They are split as evenly as
            possible across the four blobs. When ``n_points`` is not
            divisible by four, the first blobs get one extra point each, so
            every group has exactly ``n_points`` rows.
        seed: Seed for ``numpy.random.default_rng``. Red points are drawn
            before blue points from the same generator.

    Returns:
        ClusterData whose ``red`` and ``blue`` arrays have shape (n_points, 2).
    """
    rng = np.random.default_rng(seed)
    red_centers = [(-20, 5), (0, 0), (25, 15), (20, -10)]
    blue_centers = [(-10, -5), (10, 10), (15, 0), (0, -15)]

    def clusters(centers):
        # Distribute the remainder instead of silently dropping it. For
        # divisible n_points the sizes (and hence the RNG draws) match a
        # plain n_points // len(centers) split exactly.
        base, extra = divmod(n_points, len(centers))
        sizes = [base + (i < extra) for i in range(len(centers))]
        parts = [rng.normal((cx, cy), 5, size=(size, 2))
                 for (cx, cy), size in zip(centers, sizes)]
        return np.vstack(parts)

    return ClusterData(red=clusters(red_centers),
                       blue=clusters(blue_centers))

def generate_emtkd_data(n_points=400, seed=24) -> ClusterData:
    """Generate an arc-versus-disc pair of point clouds ("EMTKD" scenario).

    ``blue`` places ``n_points`` evenly spaced angles on the right half of a
    circle of radius 30 (angles from -pi/2 to pi/2). ``red`` samples
    ``n_points`` points with uniform angle in [0, 2*pi) and uniform radius in
    [5, 25). Both coordinates of every point get independent Gaussian noise
    with standard deviation 2.5.

    Args:
        n_points: Number of points in each group.
        seed: Seed for ``numpy.random.default_rng``.

    Returns:
        ClusterData whose ``red`` and ``blue`` arrays have shape (n_points, 2).
    """
    noise_sd = 2.5
    rng = np.random.default_rng(seed)

    # Blue arc. Noise draws: x first, then y.
    theta_blue = np.linspace(-np.pi / 2, np.pi / 2, n_points)
    blue_x = 30 * np.cos(theta_blue) + rng.normal(0, noise_sd, n_points)
    blue_y = 30 * np.sin(theta_blue) + rng.normal(0, noise_sd, n_points)
    blue = np.column_stack((blue_x, blue_y))

    # Red ring. Angles, then radii, then x noise, then y noise.
    theta_red = rng.uniform(0, 2 * np.pi, n_points)
    r_red = rng.uniform(5, 25, n_points)
    red_x = r_red * np.cos(theta_red) + rng.normal(0, noise_sd, n_points)
    red_y = r_red * np.sin(theta_red) + rng.normal(0, noise_sd, n_points)
    red = np.column_stack((red_x, red_y))

    return ClusterData(red=red, blue=blue)

# ---------- Plotting ----------
def plot_clusters(data: ClusterData, title: str, file_path: Path):
    """Scatter-plot both groups of ``data`` and save the figure as a PDF.

    Args:
        data: Groups to plot. ``red`` is labelled "Dataset A" and ``blue`` is
            labelled "Dataset B" in the legend.
        title: Axes title.
        file_path: Destination file. The output is always PDF, whatever the
            file's suffix.

    The figure is closed even if plotting or saving raises, so failed calls
    do not leave open figures registered with pyplot.
    """
    fig, ax = plt.subplots(figsize=(6, 6))
    try:
        # No colours are passed, so the groups use matplotlib's default colour
        # cycle -- the "red" group is not necessarily drawn in red.
        ax.scatter(data.red[:, 0], data.red[:, 1], label="Dataset A", alpha=0.7)
        ax.scatter(data.blue[:, 0], data.blue[:, 1], label="Dataset B", alpha=0.7)
        ax.set(title=title, xlabel="Dimension 1", ylabel="Dimension 2")
        ax.legend(loc="upper left")
        fig.tight_layout()
        fig.savefig(file_path, format="pdf")
    finally:
        # Release the figure on both the success and the error path.
        plt.close(fig)

# ---------- Silhouette ----------
def silhouette(data: ClusterData) -> float:
    """Return the silhouette score of the red/blue split after standardization.

    The stacked points ``data.X`` are rescaled column-wise by a
    ``StandardScaler`` fitted on all points. The silhouette score is then
    computed with ``data.labels`` (0 = red, 1 = blue) as the cluster
    assignment. Higher values mean the two groups are better separated.
    """
    scaler = StandardScaler()
    standardized = scaler.fit_transform(data.X)
    score = silhouette_score(standardized, data.labels)
    return score

# ---------- Main ----------
if __name__ == "__main__":
    # All figures go into ./figures, which is created if it does not exist.
    fig_dir = Path("figures")
    fig_dir.mkdir(exist_ok=True)

    # (title, dataset, output filename), generated in this order.
    scenarios = [
        ("MTMS", generate_mtms_data(), "figure3a_mtms.pdf"),
        ("EMTKD", generate_emtkd_data(), "figure3b_emtkd.pdf"),
    ]

    for title, dataset, filename in scenarios:
        plot_clusters(dataset, title, fig_dir / filename)

    mtms_set = scenarios[0][1]
    emtkd_set = scenarios[1][1]
    print(f"Silhouette (MTMS) : {silhouette(mtms_set):.3f}")
    print(f"Silhouette (EMTKD): {silhouette(emtkd_set):.3f}")


Silhouette (MTMS) : 0.049
Silhouette (EMTKD): 0.234
