<a href="https://colab.research.google.com/github/zfurman56/polytopes/blob/main/Polytopes.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## What are polytopes, and why should you care?

If you've spent any time doing linear algebra, you know that linear functions are especially easy to understand. Conveniently, ReLU ($\mathrm{ReLU}(x) = \max(0, x)$), the most common modern activation function, gives neural networks a special property: ReLU networks aren't quite linear, but they're close. They're *piecewise linear*<sup>[1]</sup>. You might ask: can we understand neural networks in terms of their linear pieces? This is exactly what the polytope view of neural networks (known in other corners as the "spline theory" of neural networks) hopes to do.

This lets you think of ReLU networks as an intuitive process of iteratively folding the input, approximating functions with a kind of "origami." The goal of this notebook is to transfer this mental picture to you!

So what are polytopes? The *polytopes* are the linear regions. They get the name because, as you'll see later, these regions are all convex shapes with flat sides: convex polygons in 2D and, in higher dimensions, what mathematicians call *polytopes*.
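Why convex? For a single hidden layer, the set of inputs that switch on exactly the same hidden neurons is described by one linear inequality per neuron (its pre-activation is either positive or not), and an intersection of half-spaces is always convex. The same argument extends to deeper networks, because once the earlier layers' on/off states are fixed, each later pre-activation is again an affine function of the input. Below is a minimal NumPy sketch of the one-layer case; `W1`, `b1`, `activation_pattern`, and `in_region` are hypothetical names for a random, untrained network, not anything from the networks trained later in this notebook.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical one-hidden-layer ReLU network: 2 inputs, 8 hidden units (random, untrained)
W1 = rng.normal(size=(8, 2))
b1 = rng.normal(size=8)

def activation_pattern(x):
    """Which hidden ReLUs are on (pre-activation > 0) at input x."""
    return W1 @ x + b1 > 0

def in_region(x, pattern):
    """True iff x has the given pattern, i.e. satisfies one linear inequality per neuron:
    W1[i] @ x + b1[i] > 0 where pattern[i] is on, and <= 0 where it is off."""
    pre = W1 @ x + b1
    return bool(np.all(np.where(pattern, pre > 0, pre <= 0)))

# Collect sampled points that share the activation pattern of x0
x0 = np.array([0.3, -0.2])
p0 = activation_pattern(x0)
samples = x0 + 0.5 * rng.normal(size=(2000, 2))
same = [x0] + [x for x in samples if in_region(x, p0)]

# Convexity check: points on the segment between two region members stay in the region
for p, q in zip(same[:100], same[1:101]):
    for t in np.linspace(0.0, 1.0, 11):
        assert in_region(t * p + (1 - t) * q, p0)

print(f"{len(same)} sampled points share x0's pattern; every segment between them stayed inside")
```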

For more information, I highly recommend reading [Conjecture's post](https://www.lesswrong.com/posts/eDicGjD9yte6FLSie/interpreting-neural-networks-through-the-polytope-lens) on the topic. Or, if you have a more mathematical bent, you might also enjoy the [original paper](https://proceedings.mlr.press/v80/balestriero18b.html) on the spline theory of deep learning.

<sub><sup>[1]: This doesn't follow immediately from the ReLU function itself being piecewise linear, but it's quick to see once you notice that sums and compositions of piecewise linear functions are still piecewise linear, and a ReLU network is nothing but affine layers composed with ReLUs.</sup></sub>
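The footnote's argument can be made concrete. Once you fix which ReLUs are on at a given input, every ReLU acts as multiplication by 0 or 1, so the whole network collapses into a single affine map $A x + c$ that is exact throughout that input's linear region. Here is a NumPy sketch with a hypothetical random 2-hidden-layer network; `forward` and `local_affine_map` are illustrative names, not a library API.

```python
import numpy as np

rng = np.random.default_rng(1)

# Hypothetical 2-hidden-layer ReLU network: 3 inputs -> 6 -> 6 -> 1 output
Ws = [rng.normal(size=(6, 3)), rng.normal(size=(6, 6)), rng.normal(size=(1, 6))]
bs = [rng.normal(size=6), rng.normal(size=6), rng.normal(size=1)]

def forward(x):
    """Ordinary forward pass: ReLU after every layer except the last."""
    h = x
    for W, b in zip(Ws[:-1], bs[:-1]):
        h = np.maximum(W @ h + b, 0.0)
    return Ws[-1] @ h + bs[-1]

def local_affine_map(x):
    """Return (A, c) with network(x') == A @ x' + c for every x' in x's linear region.
    Invariant inside the loop: the current hidden vector h equals A @ x + c."""
    A = np.eye(len(x))
    c = np.zeros(len(x))
    h = x
    for W, b in zip(Ws[:-1], bs[:-1]):
        pre = W @ h + b
        mask = (pre > 0).astype(float)   # frozen ReLU: multiply each unit by 0 or 1
        A = mask[:, None] * (W @ A)      # compose the linear parts
        c = mask * (W @ c + b)           # compose the offsets
        h = np.maximum(pre, 0.0)
    return Ws[-1] @ A, Ws[-1] @ c + bs[-1]

x0 = rng.normal(size=3)
A, c = local_affine_map(x0)
print(np.allclose(forward(x0), A @ x0 + c))  # True
```

Different activation patterns give different $(A, c)$ pairs, which is exactly what "piecewise linear" means here.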

## Let's get started

First, let's define our ReLU network. We make the input size, output size, hidden layer width, and number of hidden layers configurable, because this flexibility will be useful later.

In [87]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import plotly.graph_objs as go
import plotly.express as px
import copy
from scipy.ndimage import gaussian_filter
from ipywidgets import interact, FloatSlider

In [88]:
class FeedForwardNN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, hidden_layers):
        super(FeedForwardNN, self).__init__()
        self.layers = nn.ModuleList()

        # Input layer
        self.layers.append(nn.Linear(input_size, hidden_size))
        self.layers.append(nn.ReLU())

        # Hidden layers
        for _ in range(hidden_layers - 1):
            self.layers.append(nn.Linear(hidden_size, hidden_size))
            self.layers.append(nn.ReLU())

        # Output layer
        self.layers.append(nn.Linear(hidden_size, output_size))

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

Let's start with a network with a single input and a single output, since this is easiest to reason about:

In [89]:
# Define the target function
def target_function(x):
    return x**2

# Generate the dataset
n_samples = 1000
inputs = torch.rand(n_samples, 1) * 4 - 2  # Random samples in the range [-2, 2]
outputs = target_function(inputs)  # Vectorized; shape (n_samples, 1)

# Define the network, loss function, and optimizer
input_size = 1
hidden_size = 10
output_size = 1
hidden_layers = 1

net = FeedForwardNN(input_size, hidden_size, output_size, hidden_layers)
criterion = nn.MSELoss()
optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)

# Train the network
n_epochs = 10000

for epoch in range(n_epochs):
    optimizer.zero_grad()
    predictions = net(inputs)
    loss = criterion(predictions, outputs)
    loss.backward()
    optimizer.step()

    if (epoch + 1) % 1000 == 0:
        print(f'Epoch [{epoch + 1}/{n_epochs}], Loss: {loss.item():.4f}')

print("Training completed.")

Epoch [1000/10000], Loss: 0.0043
Epoch [2000/10000], Loss: 0.0020
Epoch [3000/10000], Loss: 0.0014
Epoch [4000/10000], Loss: 0.0013
Epoch [5000/10000], Loss: 0.0013
Epoch [6000/10000], Loss: 0.0013
Epoch [7000/10000], Loss: 0.0013
Epoch [8000/10000], Loss: 0.0013
Epoch [9000/10000], Loss: 0.0013
Epoch [10000/10000], Loss: 0.0013
Training completed.


In [90]:
# Generate input data (e.g., 200 points between -2 and 2)
num_points = 200
input_data = torch.linspace(-2, 2, num_points).unsqueeze(1)

# Compute the network output
with torch.no_grad():
    output_data = net(input_data)

# Convert the input and output data to lists
input_data_list = input_data.squeeze().tolist()
output_data_list = output_data.squeeze().tolist()

# Create a line plot using Plotly Express
fig = px.line(x=input_data_list, y=output_data_list, labels={'x': 'Input', 'y': 'Output'})
fig.show()

Looks like $x^2$ to me! Now let's color each piece according to the slope of the function, so the polytopes stand out:

In [91]:
# Generate input data (e.g., 200 points between -2 and 2)
num_points = 200
input_data = torch.linspace(-2, 2, num_points).unsqueeze(1)
input_data.requires_grad_(True)

# Compute the network output
output_data = net(input_data)

# Calculate the gradients
output_data.sum().backward()

# Extract gradients from the input_data tensor
input_gradients = input_data.grad

# Find the piecewise linear regions
regions = []
prev_slope = None

for i in range(1, len(input_data)):
    curr_slope = input_gradients[i]

    if prev_slope is None or torch.isclose(curr_slope, prev_slope, rtol=1e-2, atol=1e-2):
        prev_slope = curr_slope
    else:
        regions.append(i - 1)
        prev_slope = curr_slope

regions.append(len(input_data) - 1)

# Plot the piecewise linear regions
fig = go.Figure()

for i in range(len(regions)):
    start = regions[i - 1] if i > 0 else 0
    end = regions[i] + 1

    fig.add_trace(go.Scatter(x=input_data[start:end].detach().numpy().flatten(),
                             y=output_data[start:end].detach().numpy().flatten(),
                             mode='lines',
                             name=f'Region {i + 1}'))

fig.update_layout(title='ReLU Neural Network Output',
                  xaxis_title='Input',
                  yaxis_title='Output')
fig.show()
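The cell above detects regions by comparing numerically computed slopes against a tolerance (`rtol=1e-2, atol=1e-2`), which can merge two neighboring regions whose slopes happen to be close. An alternative, sketched here under the assumption of a plain stack of `nn.Linear`/`nn.ReLU` layers, is to compare activation patterns directly: a new linear region starts exactly where some ReLU switches on or off. `model` is a hypothetical untrained stand-in; since `activation_patterns` just iterates over layers, you could presumably pass the notebook's `net.layers` (an `nn.ModuleList`) instead.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical, untrained stand-in for the 1-input network above: Linear -> ReLU -> Linear
model = nn.Sequential(nn.Linear(1, 10), nn.ReLU(), nn.Linear(10, 1))

def activation_patterns(layers, x):
    """Boolean matrix whose row j records which ReLUs are active for input x[j]."""
    patterns = []
    h = x
    with torch.no_grad():
        for layer in layers:
            h = layer(h)
            if isinstance(layer, nn.ReLU):
                patterns.append(h > 0)  # after a ReLU, h > 0 exactly where the pre-activation was > 0
    return torch.cat(patterns, dim=1)

xs = torch.linspace(-2, 2, 2000).unsqueeze(1)
pats = activation_patterns(model, xs)

# A new linear region starts wherever the pattern differs from the previous sample's
changes = (pats[1:] != pats[:-1]).any(dim=1)
starts = torch.nonzero(changes).flatten() + 1
n_regions = len(starts) + 1
print(f"{n_regions} linear regions on [-2, 2]; new regions start at sample indices {starts.tolist()}")
```

With one hidden layer of 10 neurons there can be at most 11 regions along any line, since each neuron's pre-activation is linear in the input and so changes sign at most once.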

Now let's add sliders so you can play with the weight and bias of a single hidden neuron:

In [92]:
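try:
    # Colab needs this to render ipywidgets; skip it when running elsewhere
    from google.colab import output
    output.enable_custom_widget_manager()
except ImportError:
    pass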

# So we don't mess up the params of our original network
net_copy = copy.deepcopy(net)

# Generate input data
input_data = torch.linspace(-2, 2, 100).view(-1, 1)  # No gradients needed here

def update_plot(bias_div_weight, weight):
    # Update the weight and bias of a specific hidden neuron.
    # Holding bias / weight fixed keeps the bend at x = -bias / weight while the weight changes.
    neuron_idx = 0  # Index of the neuron to be updated
    bias = bias_div_weight * weight
    net_copy.layers[0].weight.data[neuron_idx] = weight
    net_copy.layers[0].bias.data[neuron_idx] = bias

    # Compute the network output
    with torch.no_grad():
        output_data = net_copy(input_data)

    # Convert the input and output data to lists
    input_data_list = input_data.squeeze().tolist()
    output_data_list = output_data.squeeze().tolist()

    # Create a line plot using Plotly Express
    fig = px.line(x=input_data_list, y=output_data_list, labels={'x': 'Input', 'y': 'Output'})
    fig.update_xaxes(range=[-2, 2])
    fig.update_yaxes(range=[-1, 5])
    fig.show()

# Create sliders to vary weight and bias, starting from the neuron's current values
default_weight = net_copy.layers[0].weight.data[0].item()
default_bias = net_copy.layers[0].bias.data[0].item()
weight_slider = FloatSlider(value=default_weight, min=-2, max=2, step=0.1, description='Weight')
bias_div_weight_slider = FloatSlider(value=default_bias / default_weight, min=-10, max=10, step=0.1, description='Bias / weight')

# Connect sliders to the update_plot function
interact(update_plot, bias_div_weight=bias_div_weight_slider, weight=weight_slider)



Notice how adjusting that neuron changed where the function "bends"? The weight changed how sharply it bends, and the other slider, which controls the ratio of bias to weight (the bend sits at $x = -\text{bias}/\text{weight}$), changed where it bends. The neuron contributes nothing to the output until it activates; from then on, it changes the slope by a constant amount. So all a single-hidden-layer ReLU network is doing is deciding where to add these bends and how sharp each one is, almost like origami.
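The bend positions and sizes can be read straight off the weights. For a hypothetical network $f(x) = \sum_i v_i \,\mathrm{ReLU}(w_i x + b_i) + c$ (the names `w`, `b`, `v`, `c` are chosen here for illustration), neuron $i$ bends the function at $x = -b_i / w_i$, and crossing that point from left to right changes the slope by $|w_i|\, v_i$. This is also why holding $b_i / w_i$ fixed pins down where a bend sits. A NumPy sketch that checks the prediction with finite differences:

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical 1-input network with 5 hidden units: f(x) = v . relu(w * x + b) + c
w = rng.normal(size=5)
b = rng.normal(size=5)
v = rng.normal(size=5)
c = 0.1

def f(x):
    return v @ np.maximum(w * x + b, 0.0) + c

kinks = -b / w              # neuron i switches on/off where w_i * x + b_i = 0
predicted = np.abs(w) * v   # slope just right of each kink minus slope just left of it

# Measure the slope change at each kink with one-sided finite differences
eps = 1e-6
measured = np.array([
    (f(x_k + eps) - f(x_k)) / eps - (f(x_k) - f(x_k - eps)) / eps
    for x_k in kinks
])

for x_k, p, m in zip(kinks, predicted, measured):
    print(f"bend at x = {x_k:+.3f}: predicted slope change {p:+.4f}, measured {m:+.4f}")
```

The sign of $v_i$ decides whether the function bends up or down, which is how a sum of such bends can trace out a curve like $x^2$.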

This intuition generalizes quite well to the 2D case. We'll train the network to approximate $f(x, y) = x^2+y^2$.

In [98]:
# Define the target function
def target_function(x, y):
    return x**2 + y**2

# Generate the dataset
n_samples = 1000
inputs = torch.rand(n_samples, 2) * 4 - 2  # Random samples in the range [-2, 2]
outputs = target_function(inputs[:, 0:1], inputs[:, 1:2])  # Vectorized; shape (n_samples, 1)

# Define the network, loss function, and optimizer
input_size = 2
hidden_size = 15
output_size = 1
hidden_layers = 1

net = FeedForwardNN(input_size, hidden_size, output_size, hidden_layers)
criterion = nn.MSELoss()
optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)

# Train the network
n_epochs = 10000

for epoch in range(n_epochs):
    optimizer.zero_grad()
    predictions = net(inputs)
    loss = criterion(predictions, outputs)
    loss.backward()
    optimizer.step()

    if (epoch + 1) % 1000 == 0:
        print(f'Epoch [{epoch + 1}/{n_epochs}], Loss: {loss.item():.4f}')

print("Training completed.")

Epoch [1000/10000], Loss: 0.0162
Epoch [2000/10000], Loss: 0.0140
Epoch [3000/10000], Loss: 0.0114
Epoch [4000/10000], Loss: 0.0099
Epoch [5000/10000], Loss: 0.0082
Epoch [6000/10000], Loss: 0.0066
Epoch [7000/10000], Loss: 0.0056
Epoch [8000/10000], Loss: 0.0051
Epoch [9000/10000], Loss: 0.0043
Epoch [10000/10000], Loss: 0.0040
Training completed.


Okay, let's check if it worked now:

In [99]:
# Example input
input_data = torch.tensor([[0.5, 0.7]])
output_data = net(input_data)

print("Output:", output_data[0,0])
print("Expected:", target_function(input_data[0,0], input_data[0,1]))

Output: tensor(0.7336, grad_fn=<SelectBackward0>)
Expected: tensor(0.7400)


Looks good! Let's plot the function the network has learned. As before, we color it according to the slope of the function, to show the polytopes.

In [100]:
# Create a high-resolution meshgrid for the input values
resolution = 200
x = np.linspace(-2, 2, resolution)
y = np.linspace(-2, 2, resolution)
x_grid, y_grid = np.meshgrid(x, y)

# Pass the input values through the network to get the output values
input_data = torch.tensor(np.array([x_grid.flatten(), y_grid.flatten()]).T, dtype=torch.float32)
input_data.requires_grad_(True)
output_data = net(input_data)
z_grid = output_data.detach().numpy().reshape(x_grid.shape)

# Calculate the gradients of the output with respect to the inputs
output_data.sum().backward()
grads = input_data.grad.detach().numpy()

# Find the unique gradient values and assign a color to each unique value
# Points in the same linear region share exactly the same gradient, so the index
# of each point's gradient among the unique gradients serves as its region label
unique_grads, unique_indices = np.unique(grads, axis=0, return_inverse=True)
colors_grid = unique_indices.reshape(x_grid.shape)

# Normalize the colors to [0, 1] range
normalized_colors = (colors_grid - colors_grid.min()) / (colors_grid.max() - colors_grid.min())

# Apply Gaussian blur to smooth the edges between regions
smooth_colors = gaussian_filter(normalized_colors, sigma=1)

# Create a surface plot with different colors for each piecewise-linear region
fig = go.Figure(go.Surface(x=x_grid, y=y_grid, z=z_grid, surfacecolor=smooth_colors, colorscale='plasma'))

fig.update_layout(scene=dict(
        xaxis_title='Input x',
        yaxis_title='Input y',
        zaxis_title='Output z'))

fig.show()

You can see that, just like in the 1D case, the output is built from the input via a series of "bends" along creases. The function looks like it's made of a mesh, almost like a low-poly video game character. The flat colored regions bounded by the creases are the polytopes.
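For a single hidden layer the creases are easy to describe: hidden neuron $i$ switches on or off along the straight line $W_{1,i} \cdot (x, y) + b_{1,i} = 0$, so the 15 hidden units draw 15 lines. A standard counting result says $n$ lines cut the plane into at most $1 + n + \binom{n}{2}$ pieces, which is 121 for $n = 15$. The sketch below counts activation patterns on a grid for a hypothetical random first layer (`W1`, `b1` are made up here); for the trained network you could try `net.layers[0].weight.detach().numpy()` and `net.layers[0].bias.detach().numpy()` in their place.

```python
import numpy as np

rng = np.random.default_rng(0)

n = 15                         # hidden units, as in the 2D network above
W1 = rng.normal(size=(n, 2))   # hypothetical first-layer weights
b1 = rng.normal(size=n)        # hypothetical first-layer biases

# Densely sample a large square and record every point's activation pattern
side = np.linspace(-20, 20, 600)
xg, yg = np.meshgrid(side, side)
points = np.stack([xg.ravel(), yg.ravel()], axis=1)    # shape (N, 2)
patterns = points @ W1.T + b1 > 0                      # shape (N, n): which side of each crease

# Encode each on/off pattern as one integer so distinct regions are easy to count
codes = patterns.astype(np.int64) @ (1 << np.arange(n))
n_found = len(np.unique(codes))

max_regions = 1 + n + n * (n - 1) // 2                 # most regions n lines can cut a plane into
print(f"found {n_found} regions on the grid; {n} lines allow at most {max_regions}")
```

A grid only finds regions large enough to contain a sample point, so this count is a lower bound on the true number of regions, and the $[-2, 2]$ square plotted above will generally show fewer.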

A good exercise to try now is to increase the `hidden_layers` parameter. What happens to the polytopes then? You might notice a pattern.