## Testing Projection-as-a-Layer
This notebook implements projection to the simplex as a differentiable (?) layer.
The algorithm implemented was described in:
    " Efficient projections to the ell_1 ball for Learning in High Dimensions "
    by Duchi et al
and also in [this note](https://arxiv.org/pdf/1309.1541.pdf)

The code in this notebook draws heavily (some might even say copies) the code freely available [here](https://github.com/smatmo/ProjectionOntoSimplex)

In [33]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
#from FPN import FPN

import os
os.environ['KMP_DUPLICATE_LIB_OK']='True' # Weird hack I needed on my machine as PyTorch was crashing.

In [57]:
# class VINet(FPN):
#     '''
#         Test implementation of variational inequality net for learning to
#         solve a VI over the probability simplex.
        
#         WARNING: This code is, as yet, untested.
        
#         Daniel McKenzie, April 20th 2021
        
#     '''
#     def __init__(self, action_dim, num_players, device,
#                  s_hi=1.0, inf_dim=10):
#         super().__init__()
#         self._device = device
#         self._lat_dim = action_dim*num_players
#         self._inf_dim = inf_dim
#         self._device = device
        
#         # Layers
#         self.fc_u = nn.Linear(lat_dim, lat_dim, bias=False)
#         self.relu = nn.ReLU()
        
#     def name(self):
#         return 'VINet'
        
#     def device(self):
#         return self._device
    
#     def lat_dim(self):
#         return self._lat_dim
    
#     def s_hi(self):
#         return self._s_hi
    
#     def project_to_simplex(self, u):
#        """
#            function handling the projection to simplex.
#        """
#        batch_size = u.shape[0]
#        mu = torch.sort(u, descending=True)[0]
#        cum_sum = torch.cumsum(mu, dim=1)
#        # Don't actually need to track gradients in next step:
#        j = torch.unsqueeze(torch.arange(1,self._lat_dim + 1,
#                           dtype = mu.dtype, device = self._device),0)
#        rho = torch.sum(j*mu - cum_sum + 1. > 0.0,dim=1, keepdim=True) - 1.
#        rho = rho.long()
#        sum_to_rho = cum_sum[torch.arange(batch_size), rho[:,0]]
#        theta = (1 - torch.unsqueeze(sum_to_rho, -1))/(rho.type(sum_to_rho.dtype) + 1)
#        w = torch.clamp(theta + u, min=0.0)
#        return w
    
#     def latent_space_forward(self, u, v):
#         u = 0.99*self.relu(self.fc_u(u) + v)
        
#         # Now do projection on to simplex
        
#         w = self.project_to_simplex(u)
        
#         return w
       
        

In [58]:
device = 'cpu'
u = torch.tensor(np.random.randn(64,20), requires_grad=True)
batch_size = u.shape[0]
lat_dim = u.shape[1]

def project_simplex(u, lat_dim, device):
        """
            function handling the projection to simplex.
        """
        
        batch_size = u.shape[0]
        mu = torch.sort(u, descending=True)[0]
        cum_sum = torch.cumsum(mu, dim=1)
        # Don't actually need to track gradients in next step:
        j = torch.unsqueeze(torch.arange(1,lat_dim + 1,
                           dtype = mu.dtype, device = device),0)
        rho = torch.sum(j*mu - cum_sum + 1. > 0.0,dim=1, keepdim=True) - 1.
        rho = rho.long()
        sum_to_rho = cum_sum[torch.arange(batch_size), rho[:,0]]
        theta = (1 - torch.unsqueeze(sum_to_rho, -1))/(rho.type(sum_to_rho.dtype) + 1)
        w = torch.clamp(theta + u, min=0.0)
        return w

    
# Testing that Torch is tracking the gradients properly. It seems to be working!
w = project_simplex(u, lat_dim, device)
print(torch.sum(w))
print(w.requires_grad)

tensor(64., dtype=torch.float64, grad_fn=<SumBackward0>)
True


In [61]:
# Test the projection function
Data = torch.rand(128,20)
FixedPoints = project_simplex(torch.rand(128, 10), 10, device = "cpu")

print(torch.sum(FixedPoints[0,:]).item())
print(FixedPoints[1,:])

0.9999998211860657
tensor([0.0430, 0.0000, 0.1703, 0.2506, 0.2973, 0.0000, 0.0000, 0.2387, 0.0000,
        0.0000])


In [62]:
## Very simple Test test NN that uses projection to the simplex as its final layer.
# Note that this is an explicit, not implicit model! 
# Just wanted to test that we can back-prop through the projection layer.

class SimpleNet(nn.Module):
    """
        Simple Two layer network for testing
        
    """

    def __init__(self, data_dim, latent_dim):
        super(SimpleNet, self).__init__()
        self._data_dim = data_dim
        self._lat_dim = latent_dim
        self._device = "cpu"
        self.fc1 = nn.Linear(in_features = self._data_dim,
                             out_features = self._lat_dim, bias=False)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(in_features = self._lat_dim,
                             out_features = self._lat_dim, bias = False)
        
    def project_to_simplex(self, u):
       """
           function handling the projection to simplex.
       """
       batch_size = u.shape[0]
       mu = torch.sort(u, descending=True)[0]
       cum_sum = torch.cumsum(mu, dim=1)
       # Don't actually need to track gradients in next step:
       j = torch.unsqueeze(torch.arange(1,self._lat_dim + 1,
                          dtype = mu.dtype, device = self._device),0)
       rho = torch.sum(j*mu - cum_sum + 1. > 0.0,dim=1, keepdim=True) - 1.
       rho = rho.long()
       sum_to_rho = cum_sum[torch.arange(batch_size), rho[:,0]]
       theta = (1 - torch.unsqueeze(sum_to_rho, -1))/(rho.type(sum_to_rho.dtype) + 1)
       w = torch.clamp(theta + u, min=0.0)
       return w
        
    def forward(self, u):
        u = self.fc1(u)
        u = self.relu(u)
        u = self.fc2(u)
        w = self.project_to_simplex(u)
        return w
        


In [64]:
# Testing training SimpleNet. The Data and labels (in this case, target fixed points) are randomly generated.
# Seems to work just fine, although it does learn to predict the uniform distribution given any input data,
# but perhaps this is the correct answer for random data?

Data = torch.rand(128,20)
FixedPoints = project_simplex(torch.rand(128, 10), 10, device = "cpu")

loss_func = nn.MSELoss()
SimpleNet1 = SimpleNet(20, 10)
print(SimpleNet1)
optimizer = optim.Adam(SimpleNet1.parameters(),lr=1e-2)

for epoch in range(20):
    prediction = SimpleNet1(Data)
    loss = loss_func(prediction, FixedPoints)
    print(loss.item())
    loss.backward()
    optimizer.step()


SimpleNet(
  (fc1): Linear(in_features=20, out_features=10, bias=False)
  (relu): ReLU()
  (fc2): Linear(in_features=10, out_features=10, bias=False)
)
0.025717373937368393
0.023060040548443794
0.021088670939207077
0.01971374824643135
0.01877906545996666
0.018249332904815674
0.017911424860358238
0.01774100959300995
0.017667662352323532
0.017655083909630775
0.017659761011600494
0.017658229917287827
0.01765677146613598
0.017650645226240158
0.0176463071256876
0.017657436430454254
0.017664175480604172
0.017684290185570717
0.017722895368933678
0.01776299439370632


In [69]:
# Test a few of simple net's predictions:
print(prediction[1,:])
print(prediction[8,:])

tensor([0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000,
        0.1000], grad_fn=<SliceBackward>)
tensor([0.0198, 0.0996, 0.1261, 0.1039, 0.0000, 0.1828, 0.0000, 0.2279, 0.0326,
        0.2073], grad_fn=<SliceBackward>)
