<a href="https://colab.research.google.com/github/zfurman56/polytopes/blob/main/Polytopes.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## What are polytopes, and why should you care?

If you've spent any time doing linear algebra, you know that linear functions are especially easy to understand. Conveniently, ReLUs, the most common modern activation function, have a special property. ReLU neural networks aren't quite linear, but they're close - they're piecewise linear<sup>[1]</sup>. You might ask: can we understand neural networks in terms of their piecewise linear regions? This is exactly what the polytope view of neural networks (in other corners, the "spline theory" of neural networks) hopes to do.

This lets you think of ReLU networks as an intuitive process of iteratively folding the input, approximating functions with a kind of "origami." The goal of this notebook is to transfer this mental picture to you!
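To make the folding picture concrete, here is a minimal sketch in plain Python. The `tent` function is a hypothetical two-neuron ReLU layer chosen for illustration (it is not part of the network trained below), and `count_linear_pieces` and `compose` are helper names invented for this example. Each time the fold is applied again, the number of linear pieces doubles.

```python
def relu(v):
    return max(v, 0.0)

def tent(x):
    # A two-neuron ReLU "layer" that folds [0, 1] in half at x = 0.5:
    # it rises from 0 to 1 on [0, 0.5], then falls back to 0 on [0.5, 1].
    return 2.0 * relu(x) - 4.0 * relu(x - 0.5)

def compose(f, k):
    # Apply f to its own output k times (a k-layer stack of folds).
    def g(x):
        for _ in range(k):
            x = f(x)
        return x
    return g

def count_linear_pieces(f, n=1024):
    # Sample f on a dyadic grid (exactly representable in floating point)
    # and count how often the slope between neighbouring samples changes.
    ys = [f(i / n) for i in range(n + 1)]
    slopes = [(ys[i + 1] - ys[i]) * n for i in range(n)]
    changes = sum(1 for a, b in zip(slopes, slopes[1:]) if a != b)
    return changes + 1

pieces = [count_linear_pieces(compose(tent, k)) for k in (1, 2, 3)]
print(pieces)  # [2, 4, 8]
```

Three folds already give eight linear pieces from only six ReLU units, which hints at why depth is such an efficient way to create many regions.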

So what are polytopes? The *polytopes* are the linear regions. They're called polytopes because, as you'll see later, each of these regions is a convex polygon - or, as mathematicians call the generalization to higher dimensions, a *polytope*.
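Convexity is easy to spot-check numerically. The sketch below uses a small, randomly initialized NumPy network (a hypothetical stand-in, not the network trained in this notebook), groups random inputs by which ReLU units are switched on, and then checks that the midpoint of two inputs from the same group lands in that group as well. The names `activation_pattern` and `regions` are invented for this example.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical toy network: 2 inputs -> 6 ReLU units -> 6 ReLU units.
W1, b1 = rng.normal(size=(6, 2)), rng.normal(size=6)
W2, b2 = rng.normal(size=(6, 6)), rng.normal(size=6)

def activation_pattern(x):
    # Record which ReLU units are "on" (pre-activation > 0) in both layers.
    pre1 = W1 @ x + b1
    pre2 = W2 @ np.maximum(pre1, 0.0) + b2
    return tuple(np.concatenate([pre1 > 0, pre2 > 0]).tolist())

# Group random points in [-1, 1]^2 by their activation pattern.
points = rng.uniform(-1.0, 1.0, size=(2000, 2))
regions = {}
for p in points:
    regions.setdefault(activation_pattern(p), []).append(p)

# Convexity check: the midpoint of two points from the same region
# should have the same activation pattern as both endpoints.
pairs_checked = 0
midpoints_inside = 0
for pattern, members in regions.items():
    for p, q in zip(members, members[1:]):
        pairs_checked += 1
        midpoints_inside += activation_pattern((p + q) / 2) == pattern

print(f"{len(regions)} regions, {midpoints_inside}/{pairs_checked} midpoints stayed inside")
```

This is only a sampled check, not a proof, but it matches the underlying argument: once the on/off pattern is fixed, every pre-activation is an affine function of the input, so each region is an intersection of half-planes.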

For more information, it's highly recommended to read Conjecture's post on the topic. Or, if you have a more mathematical bent, you might also enjoy the original paper on the spline theory of deep learning.

<sub><sup>[1]: This isn't immediately obvious just from the fact that the ReLU function is piecewise linear, but adding and composing piecewise linear functions always leaves you with another piecewise linear function.</sup></sub>

## Let's get started

First, let's make a simple ReLU network with three hidden layers.

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import plotly.graph_objs as go
from scipy.ndimage import gaussian_filter

In [None]:
class FeedForwardNN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(FeedForwardNN, self).__init__()
        self.hidden1 = nn.Linear(input_size, hidden_size)
        self.hidden2 = nn.Linear(hidden_size, hidden_size)
        self.hidden3 = nn.Linear(hidden_size, hidden_size)
        self.output = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        x = F.relu(self.hidden1(x))
        x = F.relu(self.hidden2(x))
        x = F.relu(self.hidden3(x))
        x = self.output(x)
        return x

Now, we'll train the network to approximate $f(x, y) = x+y^3$. We give it two inputs and one output, so we can make a nice 3D plot later.

In [None]:
# Define the target function
def target_function(x, y):
    return x + y**3

# Generate the dataset
n_samples = 1000
inputs = torch.rand(n_samples, 2) * 2 - 1  # Random samples in the range [-1, 1]
# Evaluate the target on all samples at once; unsqueeze gives shape (n_samples, 1)
outputs = target_function(inputs[:, 0], inputs[:, 1]).unsqueeze(1)

# Define the network, loss function, and optimizer
input_size = 2
hidden_size = 10
output_size = 1

net = FeedForwardNN(input_size, hidden_size, output_size)
criterion = nn.MSELoss()
optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)

# Train the network
n_epochs = 10000

for epoch in range(n_epochs):
    optimizer.zero_grad()
    predictions = net(inputs)
    loss = criterion(predictions, outputs)
    loss.backward()
    optimizer.step()

    if (epoch + 1) % 100 == 0:
        print(f'Epoch [{epoch + 1}/{n_epochs}], Loss: {loss.item():.4f}')

print("Training completed.")

Epoch [100/10000], Loss: 0.0290
Epoch [200/10000], Loss: 0.0237
Epoch [300/10000], Loss: 0.0214
Epoch [400/10000], Loss: 0.0192
Epoch [500/10000], Loss: 0.0170
Epoch [600/10000], Loss: 0.0144
Epoch [700/10000], Loss: 0.0112
Epoch [800/10000], Loss: 0.0073
Epoch [900/10000], Loss: 0.0044
Epoch [1000/10000], Loss: 0.0029
Epoch [1100/10000], Loss: 0.0021
Epoch [1200/10000], Loss: 0.0017
Epoch [1300/10000], Loss: 0.0014
Epoch [1400/10000], Loss: 0.0012
Epoch [1500/10000], Loss: 0.0010
Epoch [1600/10000], Loss: 0.0008
Epoch [1700/10000], Loss: 0.0007
Epoch [1800/10000], Loss: 0.0007
Epoch [1900/10000], Loss: 0.0006
Epoch [2000/10000], Loss: 0.0006
Epoch [2100/10000], Loss: 0.0006
Epoch [2200/10000], Loss: 0.0006
Epoch [2300/10000], Loss: 0.0006
Epoch [2400/10000], Loss: 0.0005
Epoch [2500/10000], Loss: 0.0005
Epoch [2600/10000], Loss: 0.0005
Epoch [2700/10000], Loss: 0.0005
Epoch [2800/10000], Loss: 0.0005
Epoch [2900/10000], Loss: 0.0005
Epoch [3000/10000], Loss: 0.0005
Epoch [3100/10000],

Okay, let's check if it worked now:

In [None]:
# Example input
input_data = torch.tensor([[0.5, 0.7]])
output_data = net(input_data)

print("Output:", output_data[0,0])
print("Expected:", target_function(input_data[0,0], input_data[0,1]))

Output: tensor([[0.8418]], grad_fn=<AddmmBackward0>)


Looks good! The output is close to the target value of $0.5 + 0.7^3 = 0.843$. Let's plot the function the network has learned. We can color the surface by the slope of the function, to make the polytopes pop out.

In [None]:
# Create a high-resolution meshgrid for the input values
resolution = 200
x = np.linspace(-1, 1, resolution)
y = np.linspace(-1, 1, resolution)
x_grid, y_grid = np.meshgrid(x, y)

# Pass the input values through the network to get the output values
input_data = torch.tensor(np.array([x_grid.flatten(), y_grid.flatten()]).T, dtype=torch.float32)
input_data.requires_grad_(True)
output_data = net(input_data)
z_grid = output_data.detach().numpy().reshape(x_grid.shape)

# Calculate the gradients of the output with respect to the inputs
output_data.sum().backward()
grads = input_data.grad.detach().numpy()

# Find the unique gradient values; the inverse index labels each grid point
# with the ID of its gradient, which we use as the color of its linear region
unique_grads, unique_indices = np.unique(grads, axis=0, return_inverse=True)
colors_grid = unique_indices.reshape(x_grid.shape)

# Normalize the colors to [0, 1] range
normalized_colors = (colors_grid - colors_grid.min()) / (colors_grid.max() - colors_grid.min())

# Apply Gaussian blur to smooth the edges between regions
smooth_colors = gaussian_filter(normalized_colors, sigma=1)

# Create a surface plot with different colors for each piecewise-linear region
fig = go.Figure(go.Surface(x=x_grid, y=y_grid, z=z_grid, surfacecolor=smooth_colors, colorscale='plasma'))

fig.update_layout(scene=dict(
        xaxis_title='Input x',
        yaxis_title='Input y',
        zaxis_title='Output z'))

fig.show()

Notice how the function looks like it's made of a mesh, almost like a video game character. Each differently colored region is one of our polytopes, corresponding to a unique activation pattern of neurons in the network.
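To see why a fixed activation pattern means a single linear piece, here is a hedged sketch using a small random NumPy network (hypothetical weights, not the trained `net`). Once the on/off state of every ReLU is frozen, each ReLU acts as a 0/1 mask, and the whole network collapses to one affine map $z \mapsto Az + c$. The matrix $A$ is the gradient with respect to the input, which is the quantity the plot above uses for coloring. The helper name `local_affine_map` is invented for this example.

```python
import numpy as np

rng = np.random.default_rng(1)

# Hypothetical toy network: 2 inputs -> 5 ReLU units -> 5 ReLU units -> 1 output.
W1, b1 = rng.normal(size=(5, 2)), rng.normal(size=5)
W2, b2 = rng.normal(size=(5, 5)), rng.normal(size=5)
W3, b3 = rng.normal(size=(1, 5)), rng.normal(size=1)

def forward(x):
    h1 = np.maximum(W1 @ x + b1, 0.0)
    h2 = np.maximum(W2 @ h1 + b2, 0.0)
    return W3 @ h2 + b3

def local_affine_map(x):
    # Freeze the on/off state of every ReLU at x. With the pattern fixed,
    # each ReLU is multiplication by a 0/1 mask, so the network reduces
    # to a single affine map z -> A @ z + c on that region.
    pre1 = W1 @ x + b1
    d1 = (pre1 > 0).astype(float)
    pre2 = W2 @ (d1 * pre1) + b2
    d2 = (pre2 > 0).astype(float)
    A = W3 @ (d2[:, None] * W2) @ (d1[:, None] * W1)
    c = W3 @ (d2 * (W2 @ (d1 * b1) + b2)) + b3
    return A, c, (d1, d2)

x = np.array([0.3, -0.2])
A, c, pattern = local_affine_map(x)
matches_at_x = np.allclose(forward(x), A @ x + c)

# A nearby point usually shares the pattern, and then the same A and c apply.
x_near = x + 1e-6 * np.array([1.0, 1.0])
_, _, pattern_near = local_affine_map(x_near)
same_region = all(np.array_equal(p, q) for p, q in zip(pattern, pattern_near))
matches_nearby = np.allclose(forward(x_near), A @ x_near + c)

print("affine map reproduces the network at x:", matches_at_x)
print("nearby point in same region:", same_region, "| same map fits:", matches_nearby)
```

Counting the distinct values of `pattern` over a grid of inputs would give the number of polytopes directly, without relying on gradients being bit-for-bit identical within a region.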