In [None]:
import numpy as np
import numpy.typing as npt
import matplotlib.pyplot as plt
from matplotlib.axes import Axes
from pydantic import BaseModel
from pathlib import Path

# Helper classes to load, filter and plot run results

In [None]:
class RunResult(BaseModel):
    """One saved run checkpoint: the arrays from an ``.npz`` file plus metadata parsed from its path."""

    name: str  # file stem of the .npz file
    inputs: npt.NDArray  # model inputs; columns 0 and 1 are plotted as x / y
    x_hat: npt.NDArray  # reconstructions, indexed like `inputs`
    router_logits: npt.NDArray  # per-sample router scores; argmax over axis 1 is the chosen expert
    expert_outputs: npt.NDArray  # stored as loaded; not used by the methods below
    generator_function_name: str  # 1st "-"-separated token of the parent directory name
    num_experts: int  # 2nd "-"-separated token of the parent directory name
    epoch_nr: int  # last "_"-separated token of the file stem

    class Config:
        # Lets pydantic accept numpy arrays as field types.
        arbitrary_types_allowed = True

    @classmethod
    def from_path(cls, data_path: Path) -> "RunResult":
        """Load a run from ``<generator>-<num_experts>[-...]/<prefix>_<epoch>.npz``.

        Args:
            data_path: Path to an ``.npz`` archive containing the arrays
                ``inputs``, ``x_hat``, ``router_logits`` and ``expert_outputs``.

        Returns:
            The loaded run, with metadata parsed from the directory and file names.

        Raises:
            KeyError: if one of the expected arrays is missing from the archive.
            IndexError / ValueError: if the directory or file name does not follow
                the naming scheme above.
        """
        dir_name_parts = data_path.parent.name.split("-")
        generator_function_name = dir_name_parts[0]
        num_experts = int(dir_name_parts[1])
        epoch_nr = int(data_path.stem.split("_")[-1])
        # np.load on an .npz returns an NpzFile that keeps the file open; use it as a
        # context manager so the handle is released. Indexing it reads each array
        # into memory, so the arrays remain valid after the archive is closed.
        with np.load(data_path) as data:
            return cls(
                name=data_path.stem,
                inputs=data["inputs"],
                x_hat=data["x_hat"],
                router_logits=data["router_logits"],
                expert_outputs=data["expert_outputs"],
                generator_function_name=generator_function_name,
                num_experts=num_experts,
                epoch_nr=epoch_nr,
            )

    def matches(self, generator_function_name: str | None = None, num_experts: int | None = None, epoch_nr: int | None = None) -> bool:
        """Return True if this run agrees with every filter that is not None."""
        if generator_function_name is not None and self.generator_function_name != generator_function_name:
            return False
        if num_experts is not None and self.num_experts != num_experts:
            return False
        if epoch_nr is not None and self.epoch_nr != epoch_nr:
            return False
        return True

    def plot_predictions(self, ax: Axes, expert_label_fmt: str = "Predicted Expert {}", max_lines: int = 100) -> None:
        """Scatter inputs and reconstructions on ``ax``, coloured by routed expert.

        Draws the inputs, then one scatter series of ``x_hat`` per expert (the
        expert is the argmax of ``router_logits`` along axis 1), then faint black
        segments joining inputs to their reconstructions. Only columns 0 and 1
        of ``inputs`` / ``x_hat`` are plotted.

        Args:
            ax: Matplotlib axes to draw on; its aspect ratio is set to equal.
            expert_label_fmt: Legend label per expert series; ``{}`` (if present)
                is replaced by the expert index.
            max_lines: Positive thinning factor for the segments. Only every
                ``max(len(inputs) // max_lines, 1)``-th sample gets one.
        """
        ax.set_aspect("equal")
        pred_expert_nr = self.router_logits.argmax(axis=1)
        ax.scatter(self.inputs[:, 0], self.inputs[:, 1], label="Input")
        # Iterate over every expert the router can select (the full range of the
        # argmax), not just up to the highest one actually chosen. Unused experts
        # keep their colour/legend slot, and empty input no longer crashes on .max().
        n_router_experts = self.router_logits.shape[1]
        for expert_nr in range(n_router_experts):
            expert_mask = pred_expert_nr == expert_nr
            ax.scatter(self.x_hat[expert_mask, 0], self.x_hat[expert_mask, 1], label=expert_label_fmt.format(expert_nr))
        step = max(len(self.inputs) // max_lines, 1)
        for i in range(0, len(self.inputs), step):
            ax.plot([self.inputs[i, 0], self.x_hat[i, 0]], [self.inputs[i, 1], self.x_hat[i, 1]], c="black", alpha=0.1)


class RunResults(BaseModel):
    """A queryable collection of ``RunResult`` objects."""

    runs: list[RunResult]

    @classmethod
    def from_path(cls, data_path: Path) -> "RunResults":
        """Recursively load every ``*.npz`` file below ``data_path``, in sorted path order."""
        npz_files = sorted(data_path.rglob("*.npz"))
        loaded = [RunResult.from_path(npz_file) for npz_file in npz_files]
        return cls(runs=loaded)

    def where(self, generator_function_name: str | None = None, num_experts: int | None = None, epoch_nr: int | None = None) -> "RunResults":
        """Return a new collection holding only the runs that match every non-None filter."""
        selected = []
        for run in self.runs:
            if run.matches(generator_function_name, num_experts, epoch_nr):
                selected.append(run)
        return RunResults(runs=selected)

    def unique_generator_function_names(self) -> list[str]:
        """Distinct generator function names over all runs, sorted ascending."""
        return sorted({run.generator_function_name for run in self.runs})

    def unique_num_experts(self) -> list[int]:
        """Distinct expert counts over all runs, sorted ascending."""
        return sorted({run.num_experts for run in self.runs})

    def unique_epoch_nrs(self) -> list[int]:
        """Distinct epoch numbers over all runs, sorted ascending."""
        return sorted({run.epoch_nr for run in self.runs})

    def __str__(self) -> str:
        """One-line summary: the run count and the distinct values of each metadata field."""
        run_count = len(self.runs)
        names = self.unique_generator_function_names()
        expert_counts = self.unique_num_experts()
        epochs = self.unique_epoch_nrs()
        return f"RunResults(len={run_count}, generator_function_names={names}, num_experts={expert_counts}, epoch_nrs={epochs})"
    
# Load every saved *.npz checkpoint below models/ and print a one-line summary.
runs = RunResults.from_path(Path("models/"))
# Fail early with a clear message instead of an opaque IndexError in the plotting cells.
assert runs.runs, "No *.npz run files found under models/ - check the working directory."
# print() uses the compact __str__ summary defined on RunResults.
print(runs)

# Plot vanilla (single-expert) AE results

In [24]:
# Plot the 'vanilla' approach: a single-expert model, i.e. a plain autoencoder.
# One row; one column per generator function; all panels at the same epoch.

epoch_nr = 4  # zero-based epoch index; titles display epoch_nr+1
num_experts = 1
subfigsize = 5
generator_function_names = runs.unique_generator_function_names()
n_rows = 1
n_cols = len(generator_function_names)
figsize = (n_cols * subfigsize, n_rows * subfigsize)
# squeeze=False keeps `axes` 2-D even with a single column. With the default,
# one generator function would give a bare Axes and axes[i] would fail.
fig, axes = plt.subplots(n_rows, n_cols, figsize=figsize, squeeze=False)
for i, generator_function_name in enumerate(generator_function_names):
    ax = axes[0, i]
    matching_runs = runs.where(generator_function_name, num_experts, epoch_nr).runs
    if not matching_runs:
        # Not every generator function necessarily has a single-expert run at this epoch.
        ax.set_title(f"{generator_function_name}: no vanilla AE run at epoch {epoch_nr+1}")
        ax.set_axis_off()
        continue
    run = matching_runs[0]
    run.plot_predictions(ax, expert_label_fmt="Predicted")
    ax.set_title(f"{generator_function_name}, vanilla AE, epoch {epoch_nr+1}")
    ax.legend()

Path("figures").mkdir(exist_ok=True)
fig.tight_layout()
fig.savefig("figures/vanilla_approach.png")
# Release the figure once it has been written to disk.
plt.close(fig)

# Plot MoE-AE representation evolution for circles
- row: epoch
- col: number of experts

In [25]:
# Plot how the circle representation evolves during training.
# Rows: epoch; columns: number of experts.

generator_function_name = "circle"
# Build the grid from the circle runs only, so it does not include epochs or
# expert counts that exist solely for other generator functions.
circle_runs = runs.where(generator_function_name)
assert circle_runs.runs, f"No runs found for generator function {generator_function_name!r}"
epoch_nrs = circle_runs.unique_epoch_nrs()
expert_counts = circle_runs.unique_num_experts()
n_rows = len(epoch_nrs)
n_cols = len(expert_counts)
subfigsize = 5
figsize = (n_cols * subfigsize, n_rows * subfigsize)
# squeeze=False keeps `axs` 2-D even with a single row or column, so axs[i, j] always works.
fig, axs = plt.subplots(n_rows, n_cols, figsize=figsize, squeeze=False)
for i, epoch_nr in enumerate(epoch_nrs):
    for j, num_experts in enumerate(expert_counts):
        ax = axs[i, j]
        matching_runs = circle_runs.where(num_experts=num_experts, epoch_nr=epoch_nr).runs
        if not matching_runs:
            # Leave a labelled blank panel for combinations that were never saved.
            ax.set_title(f"Epoch {epoch_nr+1}, {num_experts} Experts: no run")
            ax.set_axis_off()
            continue
        run = matching_runs[0]
        run.plot_predictions(ax)
        ax.set_title(f"Epoch {epoch_nr+1}, {num_experts} Experts")
        ax.legend()

Path("figures").mkdir(exist_ok=True)
fig.tight_layout()
fig.savefig("figures/gridplot-circles.png")
plt.close(fig)

# Plot MoE-AE representation results after training
- row: data generator function
- col: number of experts

In [26]:
# Representations at a fixed epoch for every generator function and expert count.
# Rows: generator function; columns: number of experts.

epoch_nr = 4  # zero-based epoch index (presumably the last one trained - TODO confirm)
generator_function_names = runs.unique_generator_function_names()
expert_counts = runs.unique_num_experts()
n_rows = len(generator_function_names)
n_cols = len(expert_counts)
subfigsize = 5
figsize = (n_cols * subfigsize, n_rows * subfigsize)
# squeeze=False keeps `axs` 2-D even with a single row or column, so axs[i, j] always works.
fig, axs = plt.subplots(n_rows, n_cols, figsize=figsize, squeeze=False)
for i, generator_function_name in enumerate(generator_function_names):
    for j, num_experts in enumerate(expert_counts):
        ax = axs[i, j]
        matching_runs = runs.where(generator_function_name, num_experts, epoch_nr).runs
        if not matching_runs:
            # Not every generator function has to be trained with every expert count.
            ax.set_title(f"{generator_function_name}, {num_experts} Experts: no run")
            ax.set_axis_off()
            continue
        run = matching_runs[0]
        run.plot_predictions(ax)
        ax.set_title(f"{generator_function_name}, {num_experts} Experts")
        ax.legend()

# Create the output directory here too, so this cell also works when run on its own.
Path("figures").mkdir(exist_ok=True)
fig.tight_layout()
fig.savefig("figures/gridplot-functions.png")
plt.close(fig)