# PyTorch example of polytope complexity
Notes:
- As a complexity measure, this can be computed during training (for example, as an early-stopping signal) or after training.
- For large sample sizes and networks, the current implementation may be slow.
- The number of polytopes may be highly sensitive to initialization, stopping criterion, and other sources of randomness, so it is recommended to average over several runs and report the variance.
- In over-trained or very complex networks, each sample often lies in its own polytope (saturation of the network). This may be due either to an overfit model or to the prevalence of neighboring polytopes with similar linear functions, which arise from ReLUs that don't meaningfully affect the network output (imagine an outgoing weight of 0).
- One may also care about the distribution of samples across polytopes. That has been studied but is a more complex subject.
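
The averaging recommended in the notes could be sketched roughly as follows. This is a minimal, self-contained illustration rather than the notebook's own procedure: the helper names `polytope_fraction` and `build_model` are hypothetical, the inputs are random Gaussian points instead of the moons data, and the networks are left untrained so that only the effect of the initialization seed is visible.

```python
import numpy as np
import torch
import torch.nn as nn


def polytope_fraction(model, X):
    """Number of distinct ReLU on/off patterns over the rows of X, divided by len(X)."""
    patterns = []
    h = X
    with torch.no_grad():
        for layer in model:  # assumes an nn.Sequential model
            h = layer(h)
            if isinstance(layer, nn.ReLU):
                # A ReLU output is positive exactly when the unit is "on"
                patterns.append((h > 0).numpy().astype(int))
    codes = np.hstack(patterns)
    return len(np.unique(codes, axis=0)) / len(X)


def build_model(seed, input_size=2, hidden_size=5, num_classes=2):
    """Hypothetical helper: same architecture for every seed, different initial weights."""
    torch.manual_seed(seed)
    return nn.Sequential(
        nn.Linear(input_size, hidden_size),
        nn.ReLU(),
        nn.Linear(hidden_size, hidden_size),
        nn.ReLU(),
        nn.Linear(hidden_size, num_classes),
    )


torch.manual_seed(0)
X = torch.randn(100, 2)  # stand-in data; an assumption, not the moons dataset

# One fraction per seed (training is omitted here for brevity)
fractions = np.array([polytope_fraction(build_model(seed), X) for seed in range(5)])
print(f"mean={fractions.mean():.3f}, variance={fractions.var(ddof=1):.5f}")
```

In practice one would train each model before measuring, and possibly vary the stopping criterion as well as the seed.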

In [2]:
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader
import numpy as np
from sklearn.datasets import make_moons

In [140]:
batch_size = 20

X_train, y_train = make_moons(100)
my_dataset = TensorDataset(torch.Tensor(X_train), torch.Tensor(y_train).long())  # create the dataset
train_loader = DataLoader(my_dataset, batch_size=batch_size)  # create the dataloader

In [155]:
input_size = X_train.shape[1]
hidden_size = 5
num_classes = 2
num_epochs = 5
learning_rate = 0.001

# Fully connected neural network with two hidden layers
model = nn.Sequential(
    nn.Linear(input_size, hidden_size),
    nn.ReLU(),
    nn.Linear(hidden_size, hidden_size),
    nn.ReLU(),
    nn.Linear(hidden_size, num_classes) 
)

In [156]:
# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)  

# Train the model
for epoch in range(num_epochs):
    for inputs, labels in train_loader:
        # Forward pass
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        
        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

In [157]:
# Polytope functions
def get_polytopes(model, dataloader, penultimate=False):
    """
    Description
    -----------
    From a ReLU neural network and a training dataset, computes the number of polytopes
    occupied by the training samples relative to the sample size. Each polytope
    corresponds to a fixed assignment of all ReLUs as either on or off, and within each
    polytope the network is a linear function. One may think of this fraction
    as the relative number of linear pieces the network has to learn on the
    data in order to perform well.
    
    Parameters
    ----------
    model : pytorch sequential network with ReLUs
    dataloader : training data loader
    penultimate : boolean, default False
        If True, only returns polytopes using the last layer of ReLUs. Set to True
        if all the prior layers are viewed as a representation learner.
        
    Returns
    -------
    fraction : float
        Number of distinct polytopes occupied by the training samples, divided by
        the number of samples.
    polytope_assignments : ndarray, length=dataloader_sample_size
        Labels encoding which samples occurred in which polytopes.
    """
    all_memberships = []
    n_samples = 0
    
    for train_x, _ in dataloader:
        n_samples += train_x.shape[0]
        polytope_memberships = []
        for layer in model: # Assumes sequential, may have to adjust based on model
            train_x = layer(train_x)
            if isinstance(layer, nn.ReLU):
                binary_preactivation = (train_x.detach().numpy() > 0).astype('int')
                polytope_memberships.append(binary_preactivation)
    
        if penultimate:
            polytope_memberships = polytope_memberships[-1]
        else:
            polytope_memberships = np.hstack(polytope_memberships)
        all_memberships.append(polytope_memberships)
    
    polytopes, assignments = np.unique(np.vstack(all_memberships), axis=0, return_inverse=True)
    assignments = assignments.ravel()  # keep the inverse 1-D regardless of NumPy version
    
    return len(polytopes) / n_samples, assignments
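
To make the counting step concrete, here is a small worked example of the `np.unique` call on a hand-written pattern matrix. The matrix is hypothetical (5 samples, 3 ReLUs) and is not taken from the model above. Each row is one sample's on/off pattern, and identical rows fall in the same polytope.

```python
import numpy as np

# Hypothetical on/off patterns: one row per sample, one column per ReLU
patterns = np.array([
    [1, 0, 1],
    [1, 0, 1],
    [0, 1, 1],
    [1, 1, 0],
    [0, 1, 1],
])

# Unique rows are the occupied polytopes; the inverse maps each sample to its polytope
unique_patterns, assignments = np.unique(patterns, axis=0, return_inverse=True)
assignments = assignments.ravel()  # keep the inverse 1-D regardless of NumPy version

fraction = len(unique_patterns) / len(patterns)
print(fraction)     # 0.6  (3 distinct patterns among 5 samples)
print(assignments)  # [1 1 0 2 0]
```

The polytope labels follow the lexicographic order of the unique rows, so they carry no meaning beyond grouping samples together.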

In [158]:
poly, labels = get_polytopes(model, train_loader, penultimate=False)

In [159]:
print(poly)
print(labels)

0.12
[ 4  6  9  0  0  0  0  6  0  9  8  2  1  8  0  0  4  1  5  0  0  6  0  5
  6  0  6  9  0  0  8  8 10  9  6  9  5  0  0  5  5  1  3  4  0  7  7  0
  0  0  7  0  0  0  7  2  5  0  9  4  3  7  9  2  0  2  8  2  8  7  4  7
  8  2  9  0  9  4  4  9  0  9 11  0  4  4  0  0  8  6  4  1  0  9  4  0
  6  0  4  7]
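
The distribution of samples across polytopes mentioned in the notes can be summarized from an assignments array like the one printed above. The following is only a sketch on a short hypothetical array. The choice of summaries (counts per polytope, number of single-sample polytopes, and Shannon entropy of the occupancy) is an assumption about what might be useful, not something the notebook prescribes.

```python
import numpy as np

# Hypothetical assignments, in the same format as the `labels` array above
assignments = np.array([0, 0, 1, 2, 0, 1, 3, 0])

counts = np.bincount(assignments)      # number of samples in each polytope
proportions = counts / counts.sum()    # occupancy distribution over polytopes
singletons = int((counts == 1).sum())  # polytopes containing exactly one sample

# Shannon entropy in nats; larger values mean samples are spread more evenly
entropy = -(proportions * np.log(proportions)).sum()

print(counts)      # [4 2 1 1]
print(singletons)  # 2
print(round(entropy, 3))
```

A large share of single-sample polytopes would be consistent with the saturation described in the notes.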
