In [None]:
import numpy as np
import numpy.typing as npt
import matplotlib.pyplot as plt
from matplotlib.axes import Axes
from pydantic import BaseModel
from pathlib import Path

# Helper classes to load, query, and plot run results

In [None]:
class RunResult(BaseModel):
    """Saved outputs of one MoE-autoencoder run at one epoch, loaded from an ``.npz`` file.

    Array shapes follow the per-field comments below.
    """

    name: str
    inputs: npt.NDArray  # (n, 2)
    x_hat: npt.NDArray  # (n, 2)
    router_logits: npt.NDArray  # (n, num_experts)
    expert_outputs: npt.NDArray  # (n, num_experts, 2)
    generator_function_name: str  # name of the function that generated the data
    num_experts: int  # number of experts
    epoch_nr: int  # epoch number

    class Config:
        # Lets pydantic accept raw numpy arrays as field values.
        arbitrary_types_allowed = True

    def get_router_probs(self) -> npt.NDArray:
        """Softmax of ``router_logits`` over the expert axis (axis 1).

        The per-row maximum is subtracted before exponentiating. This leaves the
        result mathematically unchanged but prevents ``np.exp`` from overflowing
        to ``inf`` (which would produce NaN probabilities) for large logits.
        """
        shifted = self.router_logits - self.router_logits.max(axis=1, keepdims=True)
        exp_logits = np.exp(shifted)
        return exp_logits / exp_logits.sum(axis=1, keepdims=True)

    def get_expert_rmse(self) -> npt.NDArray:
        """Get the RMSE per expert and sample.

        Compares each sample's input with every expert's output and averages the
        squared error over the coordinate axis, giving shape (n, num_experts).
        """
        return np.sqrt(np.mean((self.inputs[:, None] - self.expert_outputs) ** 2, axis=2))

    @classmethod
    def from_path(cls, data_path: Path) -> "RunResult":
        """Load a run from ``<generator>-<num_experts>[-...]/<stem>_<epoch>.npz``.

        The generator function name and the expert count are the first two
        ``-``-separated tokens of the parent directory name. The epoch number is
        the last ``_``-separated token of the file stem.
        """
        epoch_nr = int(data_path.stem.split("_")[-1])
        dir_parts = data_path.parent.name.split("-")
        generator_function_name = dir_parts[0]
        num_experts = int(dir_parts[1])
        # np.load on an .npz returns an NpzFile that holds the file open; the context
        # manager closes it. Indexing reads each array into memory before closing.
        with np.load(data_path) as data:
            return cls(
                name=data_path.stem,
                inputs=data["inputs"],
                x_hat=data["x_hat"],
                router_logits=data["router_logits"],
                expert_outputs=data["expert_outputs"],
                generator_function_name=generator_function_name,
                num_experts=num_experts,
                epoch_nr=epoch_nr,
            )

    def matches(self, generator_function_name: str | None = None, num_experts: int | None = None, epoch_nr: int | None = None) -> bool:
        """Return True if this run equals every filter that is not ``None``."""
        if generator_function_name is not None and self.generator_function_name != generator_function_name:
            return False
        if num_experts is not None and self.num_experts != num_experts:
            return False
        if epoch_nr is not None and self.epoch_nr != epoch_nr:
            return False
        return True

    def _plot_links(self, ax: Axes, targets: npt.NDArray, max_lines: int = 100) -> None:
        """Draw faint segments from ``inputs[i]`` to ``targets[i]`` for at most ``max_lines`` samples."""
        n_samples = len(self.inputs)
        # Ceiling division, so the number of drawn segments never exceeds max_lines.
        step = max(-(-n_samples // max_lines), 1)
        for i in range(0, n_samples, step):
            ax.plot([self.inputs[i, 0], targets[i, 0]], [self.inputs[i, 1], targets[i, 1]], c="black", alpha=0.1)

    def plot_reproductions(self, ax: Axes, expert_label_fmt="Predicted Expert {}") -> None:
        """Scatter the inputs and the reconstructions ``x_hat``, grouped by top-routed expert.

        A subsample of inputs is linked to its reconstruction by faint black segments.
        """
        ax.set_aspect("equal")
        pred_expert_nr = self.router_logits.argmax(axis=1)
        ax.scatter(self.inputs[:, 0], self.inputs[:, 1], label="Input")
        # Iterate over all experts (not just up to the largest selected index) so the
        # colors and legend entries line up with plot_top_expert_predictions.
        for i in range(self.num_experts):
            expert_mask = pred_expert_nr == i
            ax.scatter(self.x_hat[expert_mask, 0], self.x_hat[expert_mask, 1], label=expert_label_fmt.format(i))
        self._plot_links(ax, self.x_hat)

    def plot_top_expert_predictions(self, ax: Axes, expert_label_fmt="Expert {}") -> None:
        """Scatter the inputs and, for each sample, the output of its top-routed expert.

        Points are grouped (colored) by expert. A subsample of inputs is linked to
        its top-expert prediction by faint black segments.
        """
        ax.set_aspect("equal")
        top_expert_nr = self.router_logits.argmax(axis=1)
        ax.scatter(self.inputs[:, 0], self.inputs[:, 1], label="Input")
        for i in range(self.num_experts):
            expert_mask = top_expert_nr == i
            ax.scatter(self.expert_outputs[expert_mask, i, 0], self.expert_outputs[expert_mask, i, 1], label=expert_label_fmt.format(i))
        # Pick expert_outputs[j, top_expert_nr[j]] for every sample j.
        top_preds = self.expert_outputs[np.arange(len(self.inputs)), top_expert_nr]
        self._plot_links(ax, top_preds)
            


class RunResults(BaseModel):
    """An ordered, filterable collection of ``RunResult`` objects."""

    runs: list[RunResult]

    @classmethod
    def from_path(cls, data_path: Path) -> "RunResults":
        """Recursively load every ``*.npz`` file below ``data_path``, in sorted path order."""
        npz_paths = sorted(data_path.rglob("*.npz"))
        return cls(runs=list(map(RunResult.from_path, npz_paths)))

    def where(self, generator_function_name: str | None = None, num_experts: int | None = None, epoch_nr: int | None = None) -> "RunResults":
        """Return a new collection holding only the runs that satisfy every non-``None`` filter."""
        selected = []
        for candidate in self.runs:
            if candidate.matches(generator_function_name, num_experts, epoch_nr):
                selected.append(candidate)
        return RunResults(runs=selected)

    def unique_generator_function_names(self) -> list[str]:
        """Distinct generator function names, sorted."""
        return sorted({run.generator_function_name for run in self.runs})

    def unique_num_experts(self) -> list[int]:
        """Distinct expert counts, sorted."""
        return sorted({run.num_experts for run in self.runs})

    def unique_epoch_nrs(self) -> list[int]:
        """Distinct epoch numbers, sorted."""
        return sorted({run.epoch_nr for run in self.runs})

    def __str__(self) -> str:
        """Compact summary: run count plus the distinct values of each attribute."""
        run_count = len(self.runs)
        names = self.unique_generator_function_names()
        expert_counts = self.unique_num_experts()
        epochs = self.unique_epoch_nrs()
        return f"RunResults(len={run_count}, generator_function_names={names}, num_experts={expert_counts}, epoch_nrs={epochs})"
    
# Load every saved run below the models directory. Fail loudly here if nothing was
# found, rather than with an opaque IndexError in the plotting cells below.
models_dir = Path("models/")
runs = RunResults.from_path(models_dir)
if not runs.runs:
    raise FileNotFoundError(f"No *.npz run files found under {models_dir.resolve()}")
# print() uses the compact __str__; the default pydantic repr would dump every array.
print(runs)

# Plot expert outputs for the circle data

In [None]:
# Show every expert's prediction in light grey. Overlay, in each expert's color,
# the predictions from the expert the router ranked highest for that sample.
run = runs.where(generator_function_name="circle", num_experts=4, epoch_nr=4).runs[0]
expert_outputs = run.expert_outputs
router_logits = run.router_logits
top_expert_nr = router_logits.argmax(axis=1)  # same for every expert, so computed once

fig, ax = plt.subplots()
ax.set_aspect("equal")
ax.scatter(run.inputs[:, 0], run.inputs[:, 1], label="Input", c="C0")
# All predictions go behind everything else; top-expert points are drawn over them.
ax.scatter(expert_outputs[:, :, 0], expert_outputs[:, :, 1], label="non-top preds", c="#DDD", marker="x", zorder=-1)
for expert_nr in range(run.num_experts):
    top_outputs = expert_outputs[top_expert_nr == expert_nr, expert_nr]
    # C0 is used for the inputs, so expert k gets color C{k+1}.
    ax.scatter(top_outputs[:, 0], top_outputs[:, 1], label=f"Expert {expert_nr}", c=f"C{expert_nr+1}")

ax.set_title(f"{run.generator_function_name}, {run.num_experts} experts, epoch {run.epoch_nr+1}")
ax.legend();

# Plot router probabilities and per-expert RMSE

In [None]:
# Per-sample router probabilities (left) and per-expert reconstruction RMSE (right).
run = runs.where(generator_function_name="circle", num_experts=4, epoch_nr=4).runs[0]
router_probs = run.get_router_probs()
expert_rmse = run.get_expert_rmse()
fig, axs = plt.subplots(1, 2, figsize=(10, 5))
for expert_nr in range(run.num_experts):
    # Use color C{k+1} for expert k (C0 is reserved for inputs), as in the expert-output scatter plot.
    color = f"C{expert_nr+1}"
    axs[0].plot(router_probs[:, expert_nr], label=f"Expert {expert_nr} prob", c=color, linestyle="-")
    axs[1].plot(expert_rmse[:, expert_nr], label=f"Expert {expert_nr} RMSE", c=color, linestyle="-")

axs[0].set(title="Router probabilities", xlabel="Sample index", ylabel="Probability")
axs[1].set(title="Expert RMSE", xlabel="Sample index", ylabel="RMSE")
# plt.legend() would label only the current (right-hand) axes; give each panel its own legend.
for panel in axs:
    panel.legend()

# Plot vanilla AE results

In [None]:
# Plot the 'vanilla' approach: a single-expert autoencoder for each generator function.
# Layout: one row, one column per generator function, fixed epoch.
epoch_nr = 4
num_experts = 1
subfigsize = 5
figures_dir = Path("figures_top_experts")

generator_function_names = runs.unique_generator_function_names()
n_rows = 1
n_cols = len(generator_function_names)
figsize = (n_cols * subfigsize, n_rows * subfigsize)
fig, axes = plt.subplots(n_rows, n_cols, figsize=figsize, squeeze=False)
for ax, generator_function_name in zip(axes[0], generator_function_names):
    run = runs.where(generator_function_name, num_experts, epoch_nr).runs[0]
    run.plot_top_expert_predictions(ax, expert_label_fmt="Predicted")
    ax.set_title(f"{generator_function_name}, vanilla AE, epoch {epoch_nr+1}")
    ax.legend()

# Create the directory the figure is actually saved into, so savefig cannot fail.
figures_dir.mkdir(parents=True, exist_ok=True)
fig.tight_layout()
fig.savefig(figures_dir / "vanilla_approach.png")
plt.close(fig)

# Plot MoE-AE representation evolution for circles
- row: epoch
- col: number of experts

In [None]:
# Top-expert predictions on the circle data.
# Rows: epoch number; columns: number of experts.
generator_function_name = "circle"
subfigsize = 5
figures_dir = Path("figures_top_experts")

# Take the grid dimensions from the circle runs only, so epochs or expert counts that
# exist only for other generator functions do not produce empty grid cells.
circle_runs = runs.where(generator_function_name)
epoch_nrs = circle_runs.unique_epoch_nrs()
num_experts_values = circle_runs.unique_num_experts()
n_rows = len(epoch_nrs)
n_cols = len(num_experts_values)
figsize = (n_cols * subfigsize, n_rows * subfigsize)
fig, axs = plt.subplots(n_rows, n_cols, figsize=figsize, squeeze=False)
for i, epoch_nr in enumerate(epoch_nrs):
    for j, num_experts in enumerate(num_experts_values):
        ax = axs[i, j]
        matching = circle_runs.where(num_experts=num_experts, epoch_nr=epoch_nr).runs
        if not matching:
            # Mark the missing combination visibly instead of raising IndexError.
            ax.set_axis_off()
            ax.text(0.5, 0.5, "no run", ha="center", va="center", transform=ax.transAxes)
            continue
        matching[0].plot_top_expert_predictions(ax)
        ax.set_title(f"Epoch {epoch_nr+1}, {num_experts} Experts")
        ax.legend()

# Create the directory the figure is actually saved into, so savefig cannot fail.
figures_dir.mkdir(parents=True, exist_ok=True)
fig.tight_layout()
fig.savefig(figures_dir / "gridplot-circles.png")
plt.close(fig)

# Plot MoE-AE representation results after training
- row: generator function (dataset)
- col: number of experts

In [None]:
# Top-expert predictions for every generator function at a fixed epoch.
# Rows: generator function; columns: number of experts.
epoch_nr = 4
subfigsize = 5
figures_dir = Path("figures_top_experts")

epoch_runs = runs.where(epoch_nr=epoch_nr)
generator_function_names = epoch_runs.unique_generator_function_names()
num_experts_values = epoch_runs.unique_num_experts()
n_rows = len(generator_function_names)
n_cols = len(num_experts_values)
figsize = (n_cols * subfigsize, n_rows * subfigsize)
fig, axs = plt.subplots(n_rows, n_cols, figsize=figsize, squeeze=False)
for i, generator_function_name in enumerate(generator_function_names):
    for j, num_experts in enumerate(num_experts_values):
        ax = axs[i, j]
        matching = epoch_runs.where(generator_function_name, num_experts).runs
        if not matching:
            # Mark the missing combination visibly instead of raising IndexError.
            ax.set_axis_off()
            ax.text(0.5, 0.5, "no run", ha="center", va="center", transform=ax.transAxes)
            continue
        matching[0].plot_top_expert_predictions(ax)
        ax.set_title(f"{generator_function_name}, {num_experts} Experts")
        ax.legend()

fig.suptitle(f"Epoch {epoch_nr+1}")
# This cell never created the output directory before; make sure it exists.
figures_dir.mkdir(parents=True, exist_ok=True)
fig.tight_layout()
fig.savefig(figures_dir / "gridplot-functions.png")
plt.close(fig)