# BCSoftmax Implementation

In [1]:
import torch
import numpy as np


def _bcsoftmax1d(x, budget):
    """Budget Constrained Softmax function for vector.

    Args:
        x (Tensor): input vector. shape: (n_outputs, )
        budget (Tensor): budget (constraint) vector. shape: (n_outputs, )

    Returns:
        y (Tensor): output probability vector. shape: (n_outputs, ). Satisfying the constraints y_i <= budget_i.
    
    """
    x = x - torch.max(x, dim=0)[0] # normalization to avoid numerical errors
    exp_x = torch.exp(x)
    # sorting
    _, indices = torch.sort(budget / exp_x, descending=False)
    exp_x = exp_x[indices]
    budget = budget[indices]
    # find K_B
    r = torch.sum(exp_x) - (torch.cumsum(exp_x, dim=0) - exp_x)
    s = 1.0 - (torch.cumsum(budget, dim=0) - budget)
    z = r/s
    is_in_KB = (s > 0) * (exp_x / z > budget)
    # compute outputs
    s = 1 - torch.sum(budget * is_in_KB)
    r = torch.sum(exp_x * (~is_in_KB))
    y = torch.where(~is_in_KB, s * exp_x / r, budget)
    # undo sorting
    _, inv_indices = torch.sort(indices, descending=False)
    return y[inv_indices]


class BCSoftmax1d(torch.autograd.Function):
    """Autograd implementation of Budget Constrained Softmax function for vector.

    forward(x, c) evaluates ``_bcsoftmax1d(x, c)``. backward returns the
    vector-Jacobian products with respect to ``x`` and ``c``, treating the set
    K_B of budget-clamped entries as fixed.

    ``generate_vmap_rule = True`` lets ``torch.vmap`` batch this function, so
    forward/setup_context/backward only use vmap-compatible tensor ops.
    """
    generate_vmap_rule = True

    @staticmethod
    def forward(x, c):
        y = _bcsoftmax1d(x, c)
        return y

    @staticmethod
    def setup_context(ctx, inputs, output):
        x, c = inputs
        # Entries in K_B are returned as the budget value itself, so they are
        # recovered by exact equality between budget and output.
        is_in_KB = c == output
        ctx.save_for_backward(x, c, is_in_KB)

    @staticmethod
    def backward(ctx, grad_y):
        """Return (grad_x, grad_c) for the upstream gradient ``grad_y``.

        With p_i = exp(x_i) / sum_{k not in K_B} exp(x_k) for i not in K_B
        (p_i = 0 otherwise) and s = 1 - sum_{j in K_B} c_j:

            dy_i/dx_j = s * (delta_ij * p_i - p_i * p_j)   i, j not in K_B
            dy_i/dc_j = -p_i                                i not in K_B, j in K_B
            dy_j/dc_j = 1                                   j in K_B

        and every other partial derivative is zero.
        """
        x, c, is_in_KB = ctx.saved_tensors
        free = ~is_in_KB

        # p depends only on differences between free scores, so shift by the
        # largest free score before exponentiating. Exponentiating the raw
        # scores overflows (or underflows to 0) for large-magnitude inputs and
        # produces NaN gradients even though the forward pass is finite.
        shift = torch.max(torch.where(free, x, float("-inf")), dim=0)[0]
        exp_x = torch.exp(x - shift)
        r = torch.sum(torch.where(free, exp_x, 0))
        # torch.where also discards the values computed for entries in K_B
        # (possibly inf, or nan when no entry is free).
        p = torch.where(free, exp_x / r, 0)
        s = 1 - torch.sum(c * is_in_KB)

        # Vector-Jacobian products in closed form (no n x n Jacobian needed).
        centered = grad_y - torch.sum(grad_y * p)
        grad_x = s * p * centered
        grad_c = torch.where(is_in_KB, centered, 0)
        return grad_x, grad_c


In [2]:
######### Use these functions! #########
# Single-vector entry point: inputs have shape (n_classes, ).
bcsoftmax1d = BCSoftmax1d.apply
# Batched entry point: the same autograd function mapped over dimension 0
# of both arguments, so inputs have shape (batch_size, n_classes).
bcsoftmax2d = torch.vmap(BCSoftmax1d.apply, in_dims=0, out_dims=0)

In [3]:
def _bcsoftmax1d_naive(x, budget):
    """A naive implementation of bcsoftmax1d for testing.
    """
    x = x - torch.max(x, dim=0)[0] # normalization to avoid numerical errors
    exp_x = torch.exp(x)
    y = exp_x / torch.sum(exp_x)
    is_in_KB = torch.zeros_like(x, dtype=torch.bool)
    for _ in range(len(x)):
        is_in_KB = torch.logical_or(is_in_KB, y > budget)
        s = 1 - torch.sum(budget[is_in_KB])
        r = torch.sum(exp_x[~is_in_KB])
        y = torch.where(
            is_in_KB,
            budget,
            s * exp_x / r 
        )
    return y

batch_size = 32
n_classes = 10

# Sweep the lower bound c of the random budgets over 0.05, 0.10, ..., 1.0.
for c in np.arange(1, 21) / 20.0:
    # Random double-precision scores (gradcheck below needs double inputs).
    X = (torch.randn(batch_size, n_classes, dtype=torch.double) * 2).requires_grad_()

    # Redraw budgets from [c, 1) until every row sums to more than 1.
    while True:
        budget = (
            c + ((1-c) * torch.rand(batch_size, n_classes, dtype=torch.double))
        ).requires_grad_()
        if torch.all(budget.sum(dim=1) > 1):
            break

    # Forward check: the batched function must agree with the naive
    # reference applied row by row, and respect every budget.
    actual = bcsoftmax2d(X, budget)
    expected = torch.vstack(
        [_bcsoftmax1d_naive(row_x, row_budget) for row_x, row_budget in zip(X, budget)]
    )
    torch.testing.assert_close(actual, expected)
    assert torch.all(actual <= budget), "Budget Constraint Error"

    # Backward check: compare the custom backward with numerical gradients.
    torch.autograd.gradcheck(bcsoftmax2d, (X, budget))

# Forward check: with budget_i >= 1.0 for all i, no constraint is active and
# bcsoftmax(x, budget) must equal softmax(x). Uses X from the last iteration.
actual = bcsoftmax2d(X, torch.ones_like(X))
torch.testing.assert_close(actual, torch.nn.functional.softmax(X, dim=1))

## MNIST Example

In [4]:
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR


class Net(nn.Module):
    """Small convolutional classifier whose output layer is a log-BCSoftmax.

    Two 3x3 convolutions, a 2x2 max-pool, and two fully connected layers with
    dropout. The 10 final scores are passed through bcsoftmax2d together with
    a per-sample budget and the logarithm of the result is returned.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        # 9216 = 64 * 12 * 12, the flattened size for 28x28 single-channel input.
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x, c):
        """Return log-probabilities for images ``x`` under budgets ``c``.

        ``c`` is forwarded unchanged to bcsoftmax2d as the budget argument.
        """
        # Convolutional feature extractor.
        features = F.relu(self.conv1(x))
        features = F.relu(self.conv2(features))
        features = self.dropout1(F.max_pool2d(features, 2))

        # Fully connected head producing one score per class.
        hidden = F.relu(self.fc1(torch.flatten(features, 1)))
        scores = self.fc2(self.dropout2(hidden))

        return torch.log(bcsoftmax2d(scores, c))


def train(model, device, train_loader, optimizer, epoch, max_budget, log_interval):
    """Train ``model`` for one pass over ``train_loader``.

    Every sample gets the same budget ``max_budget`` for each of the 10
    classes. The loss is the negative log-likelihood of the model output.
    Progress is printed whenever the batch index is a multiple of
    ``log_interval``; ``epoch`` is only used in that message.
    """
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data = data.to(device)
        target = target.to(device)
        # Constant budget matrix of shape (batch, 10).
        budget = torch.ones(len(target), 10).to(device) * max_budget

        optimizer.zero_grad()
        loss = F.nll_loss(model(data, budget), target)
        loss.backward()
        optimizer.step()

        if batch_idx % log_interval != 0:
            continue
        seen = batch_idx * len(data)
        percent = 100. * batch_idx / len(train_loader)
        print(
            'Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, seen, len(train_loader.dataset), percent, loss.item()
            )
        )
            

def test(model, device, test_loader, max_budget):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data, torch.ones(len(target), 10).to(device) * max_budget)
            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    print(
        '\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
            test_loss,
            correct,
            len(test_loader.dataset),
            100. * correct / len(test_loader.dataset))
    )

ModuleNotFoundError: No module named 'torchvision'

### Load dataset

In [6]:
use_cuda = torch.cuda.is_available()
use_mps = torch.backends.mps.is_available()

torch.manual_seed(0)

# Device preference: CUDA first, then MPS, otherwise CPU.
device = torch.device("cuda" if use_cuda else "mps" if use_mps else "cpu")

# Convert images to tensors, then normalise with mean 0.1307 and std 0.3081.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

# dataset1 is the training split (downloaded on demand); dataset2 is the
# held-out split read from the same directory.
dataset1 = datasets.MNIST('./data', train=True, download=True, transform=transform)
dataset2 = datasets.MNIST('./data', train=False, transform=transform)

### Run experiments with max_budget = 0.3, 0.6, 0.9

In [7]:
# Hyper-parameters shared by all runs.
batch_size = 128
test_batch_size = 128
lr = 1.0
gamma = 0.7  # multiplicative learning-rate decay applied after every epoch
epochs = 10
log_interval = 1000

# Train and evaluate a fresh model for each budget value.
for max_budget in [0.3, 0.6, 0.9]:
    print(f"Max Budget: {max_budget}")
    # Shuffle the training data on every device. Previously 'shuffle' lived in
    # the CUDA-only kwargs, so CPU/MPS runs trained on a fixed sample order
    # and CUDA runs needlessly shuffled the test loader as well.
    train_kwargs = {'batch_size': batch_size, 'shuffle': True}
    test_kwargs = {'batch_size': test_batch_size}
    if use_cuda:
        # Loader options that are only applied when running on CUDA.
        cuda_kwargs = {
            'num_workers': 1,
            'pin_memory': True,
        }
        train_kwargs.update(cuda_kwargs)
        test_kwargs.update(cuda_kwargs)

    train_loader = torch.utils.data.DataLoader(dataset1, **train_kwargs)
    test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)

    model = Net().to(device)
    optimizer = optim.Adadelta(model.parameters(), lr=lr)

    # Multiply the learning rate by gamma after each epoch.
    scheduler = StepLR(optimizer, step_size=1, gamma=gamma)
    for epoch in range(1, epochs + 1):
        train(model, device, train_loader, optimizer, epoch, max_budget, log_interval)
        test(model, device, test_loader, max_budget)
        scheduler.step()
    print()

Max Budget: 0.3

Test set: Average loss: 1.2095, Accuracy: 5390/10000 (54%)


Test set: Average loss: 1.2076, Accuracy: 5964/10000 (60%)


Test set: Average loss: 1.2062, Accuracy: 5895/10000 (59%)


Test set: Average loss: 1.2064, Accuracy: 5615/10000 (56%)


Test set: Average loss: 1.2058, Accuracy: 5806/10000 (58%)


Test set: Average loss: 1.2057, Accuracy: 5861/10000 (59%)


Test set: Average loss: 1.2057, Accuracy: 5769/10000 (58%)


Test set: Average loss: 1.2057, Accuracy: 5698/10000 (57%)


Test set: Average loss: 1.2057, Accuracy: 5645/10000 (56%)


Test set: Average loss: 1.2059, Accuracy: 5589/10000 (56%)


Max Budget: 0.6

Test set: Average loss: 0.5319, Accuracy: 9778/10000 (98%)


Test set: Average loss: 0.5239, Accuracy: 9843/10000 (98%)


Test set: Average loss: 0.5236, Accuracy: 9861/10000 (99%)


Test set: Average loss: 0.5207, Accuracy: 9878/10000 (99%)


Test set: Average loss: 0.5206, Accuracy: 9886/10000 (99%)


Test set: Average loss: 0.5205, Accuracy: 9879/1000