<a href="https://colab.research.google.com/github/vfrantc/quaternion_neurons/blob/main/visualize_quaternion_batchnorm.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# %pip (unlike !pip) installs into the environment of the running kernel.
# DecisionBoundaryDisplay, used below, was added in scikit-learn 1.1; no uninstall is needed, -U upgrades in place.
%pip install -q -U "scikit-learn>=1.1"
# Pinned to the commit this notebook was last run against, for reproducibility.
%pip install -q "git+https://github.com/TParcollet/Quaternion-Neural-Networks.git@f8de5d5e5a3f9c694a0d62cffc64ec4ccdffd1bc"

Found existing installation: scikit-learn 1.2.1
Uninstalling scikit-learn-1.2.1:
  Successfully uninstalled scikit-learn-1.2.1
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting scikit-learn
  Using cached scikit_learn-1.2.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (9.8 MB)
Installing collected packages: scikit-learn
Successfully installed scikit-learn-1.2.1
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting git+https://github.com/TParcollet/Quaternion-Neural-Networks.git
  Cloning https://github.com/TParcollet/Quaternion-Neural-Networks.git to /tmp/pip-req-build-bwxcj27e
  Running command git clone --filter=blob:none --quiet https://github.com/TParcollet/Quaternion-Neural-Networks.git /tmp/pip-req-build-bwxcj27e
  Resolved https://github.com/TParcollet/Quaternion-Neural-Networks.git to commit f8de5d5e5a3f9c694a0d62cffc64ec4ccdffd1bc
  Preparing metadata

In [None]:
import os

import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from sklearn.datasets import load_iris
from sklearn.inspection import DecisionBoundaryDisplay


iris = load_iris()

# Evaluation grid spanning the observed range of the first two iris features
# (columns 0 and 1 of iris.data), 50 points per axis.
feature_1, feature_2 = np.meshgrid(
    *(
        np.linspace(iris.data[:, col].min(), iris.data[:, col].max(), num=50)
        for col in (0, 1)
    )
)
# One (feature_1, feature_2) row per grid point -> shape (50 * 50, 2).
grid = np.column_stack((feature_1.ravel(), feature_2.ravel()))

In [None]:
# Visualize internal covariate shift with BatchNormalization on an MLP
class MLP(nn.Module):
    """Two-layer perceptron with an optional BatchNorm1d after the hidden layer.

    Args:
        with_bn: if True, hidden activations pass through ``self.bn`` before
            the ReLU; otherwise ``self.bn`` is skipped in ``forward``.
        in_features: number of input features (defaults to 2, the two iris
            features plotted in this notebook).
        hidden_features: width of the hidden layer.
        num_classes: number of output classes.
        **kwargs: accepted and ignored.

    ``forward`` returns class probabilities (softmax over dim 1), not logits.
    """

    def __init__(self, with_bn=False, in_features=2, hidden_features=10,
                 num_classes=3, **kwargs):
        super().__init__()
        self.with_bn = with_bn
        self.layer_1 = nn.Linear(in_features, hidden_features)
        # Registered unconditionally, so the state_dict contains ``bn.*`` entries
        # even when with_bn=False. Saved checkpoints presumably rely on that,
        # so do not make this conditional.
        self.bn = nn.BatchNorm1d(hidden_features)
        self.layer_2 = nn.Linear(hidden_features, num_classes)
        self.relu = nn.ReLU()
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        """Map a (batch, in_features) tensor to (batch, num_classes) probabilities."""
        x = self.layer_1(x)
        if self.with_bn:
            x = self.bn(x)
        x = self.relu(x)
        x = self.layer_2(x)
        return self.softmax(x)

# Checkpoints from an earlier training run. If they are missing (e.g. on a fresh
# runtime) the networks are trained here and the checkpoints written, so the
# notebook still runs end to end.
MLP_WITHOUT_BN_PATH = 'mlp_without_bn_parameters.pt'
MLP_WITH_BN_PATH = 'mlp_with_bn_parameters.pt'
SEED = 0
N_EPOCHS = 2000  # full-batch optimisation steps when training from scratch
LEARNING_RATE = 1e-2


def _train_mlp(net):
    """Fit ``net`` in place, full-batch with Adam, on the first two iris features."""
    # Same raw (unscaled) features that the plotting grid spans.
    inputs = torch.as_tensor(iris.data[:, :2], dtype=torch.float32)
    targets = torch.as_tensor(iris.target, dtype=torch.long)
    optimizer = torch.optim.Adam(net.parameters(), lr=LEARNING_RATE)
    loss_fn = nn.NLLLoss()
    net.train()  # BatchNorm uses batch statistics and updates its running stats
    for _ in range(N_EPOCHS):
        optimizer.zero_grad()
        # MLP.forward already applies softmax, so train on log-probabilities;
        # the clamp avoids log(0) for saturated outputs.
        loss = loss_fn(torch.log(net(inputs).clamp_min(1e-12)), targets)
        loss.backward()
        optimizer.step()


def load_or_train_mlp(with_bn, path):
    """Return an ``MLP(with_bn=with_bn)`` in eval mode.

    Weights come from ``path`` when that file exists. Otherwise the network is
    trained with ``_train_mlp`` and its ``state_dict`` is saved to ``path``.
    """
    torch.manual_seed(SEED)  # deterministic initialisation in case we train
    net = MLP(with_bn=with_bn)
    if os.path.exists(path):
        # map_location lets a checkpoint saved on GPU load on a CPU-only runtime.
        # NOTE: torch.load unpickles the file; only load checkpoints you trust.
        net.load_state_dict(torch.load(path, map_location='cpu'))
    else:
        _train_mlp(net)
        torch.save(net.state_dict(), path)
    net.eval()  # BatchNorm must use its running statistics at inference time
    return net


def predict_proba(net, points):
    """Class probabilities from ``net`` for an (n, 2) array of feature pairs."""
    with torch.no_grad():  # inference only: no autograd graph needed
        return net(torch.as_tensor(points, dtype=torch.float32)).numpy()


# Without / with BatchNormalization
mlp_without_bn = load_or_train_mlp(False, MLP_WITHOUT_BN_PATH)
mlp_with_bn = load_or_train_mlp(True, MLP_WITH_BN_PATH)
mlps = {'without BN': mlp_without_bn, 'with BN': mlp_with_bn}
x_label, y_label = iris.feature_names[:2]

# Plot decision boundaries.
# DecisionBoundaryDisplay.from_estimator expects a fitted scikit-learn estimator,
# so the torch model is evaluated on the grid here and the display is built
# directly from that precomputed response.
fig, axes = plt.subplots(1, 2, figsize=(10, 5))
for ax, (name, net) in zip(axes, mlps.items()):
    predicted_class = predict_proba(net, grid).argmax(axis=1).reshape(feature_1.shape)
    DecisionBoundaryDisplay(
        xx0=feature_1, xx1=feature_2, response=predicted_class,
        xlabel=x_label, ylabel=y_label,
    ).plot(ax=ax, alpha=0.5)
    ax.scatter(iris.data[:, 0], iris.data[:, 1], c=iris.target, s=20, edgecolor='k')
    ax.set_title(f'MLP {name}')
fig.suptitle('Decision boundaries for MLP with and without BN', fontsize=20)
plt.show()

# Plot class probability distributions (rows: models, columns: classes).
# A shared 0-1 colour scale keeps the six panels comparable.
fig, axes = plt.subplots(2, 3, figsize=(15, 10), sharex=True, sharey=True)
for row_axes, (name, net) in zip(axes, mlps.items()):
    class_probs = predict_proba(net, grid)
    for j, (ax, class_name) in enumerate(zip(row_axes, iris.target_names)):
        # pcolormesh centres each cell on its grid point, so the colour map
        # lines up with the scattered samples.
        mesh = ax.pcolormesh(
            feature_1, feature_2, class_probs[:, j].reshape(feature_1.shape),
            shading='auto', vmin=0.0, vmax=1.0,
        )
        ax.scatter(
            iris.data[iris.target == j, 0],
            iris.data[iris.target == j, 1],
            s=15, marker='o', c='white',
            alpha=0.7
        )
        ax.set_title(f'{class_name} probability ({name})')
    row_axes[0].set_ylabel(y_label)
for ax in axes[-1]:
    ax.set_xlabel(x_label)
fig.colorbar(mesh, ax=axes.ravel().tolist(), label='probability')
fig.suptitle('Class probability distributions for MLP with and without BN', fontsize=20)
plt.show()

FileNotFoundError: ignored