In [None]:
import numpy as np
import torch
import torch.optim as optim
import matplotlib.pyplot as plt

In [None]:
# Enable IPython's autoreload extension in mode 2, so imported modules are
# re-imported before each cell runs and edits to them are picked up live.
%load_ext autoreload
%autoreload 2

In [None]:
class SoftSort(torch.nn.Module):
    """
    Differentiable relaxation of the sorting operator.

    For every score vector in a batch this produces a row-stochastic matrix
    whose i-th row is a softmax over the input positions, peaked at the
    position that holds the i-th smallest score (ascending order).
    """

    def __init__(self, tau: float = 1.0, hard: bool = False, power: float = 1.0):
        """
        Initialize the class
        :param tau: temperature parameter; the negated distances are divided by it
            before the softmax, so smaller values give sharper rows
        :param hard: whether to return one-hot rows (hard sorting) instead of the
            soft relaxation; gradients still flow through the soft matrix
        :param power: power to use in the semi-metric d(x, y) = |x - y| ** power
        """
        super().__init__()
        self.hard = hard
        self.tau = tau
        self.power = power

    def forward(self, scores: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the class
        :param scores: The scores to be sorted, shape (batch, n)
        :return: Tensor of shape (batch, n, n). Row i is a distribution over the
            n input positions for the i-th element of the ASCENDING sort
            (one-hot when ``hard`` is True).
        """
        scores = scores.unsqueeze(-1)
        # Named `sorted_scores` so the builtin `sorted` is not shadowed.
        sorted_scores = scores.sort(descending=False, dim=1)[0]
        # Entry [b, i, j] is -|scores[b, j] - sorted_scores[b, i]| ** power / tau.
        neg_distance = (
            (scores.transpose(1, 2) - sorted_scores).abs().pow(self.power).neg()
            / self.tau
        )
        P_hat = neg_distance.softmax(-1)

        if self.hard:
            # zeros_like already inherits dtype and device from P_hat.
            P = torch.zeros_like(P_hat)
            P.scatter_(-1, P_hat.topk(1, -1)[1], value=1)
            # Straight-through estimator: the forward value is the one-hot P,
            # while the gradient is that of the soft P_hat.
            P_hat = (P - P_hat).detach() + P_hat
        return P_hat

In [None]:
def loss_non_diff(
    A: torch.Tensor,
    X: torch.Tensor,
    O: torch.Tensor,
    W: torch.Tensor,
    betas: torch.Tensor,
) -> (torch.Tensor, torch.Tensor):
    """
    Reference (non-differentiable) evaluation of a set of actions; the visiting
    order of the features comes from ``torch.argsort``.
    :param A: Action applied to each feature
    :param X: The original value of each feature (left unmodified)
    :param O: Ordering scores; features are visited in argsort(O) order
    :param W: The weighted adjacency matrix between features
    :param betas: The relative mutability of each feature (passed through a sigmoid)
    :return: (X_bar, cost) where X_bar is X after every action has been
        propagated through W plus the identity, and cost is the sum of
        A[i] ** 2 * sigmoid(betas[i]) over all features.
    """
    # Adjacency weights plus a unit diagonal, so an action on feature i also
    # changes feature i itself.
    effect = W + torch.eye(W.size(0))
    updated = X.clone()
    total = 0.0
    for idx in torch.argsort(O):
        # Column idx holds the effect of acting on feature idx on every feature.
        updated += A[idx] * effect[:, idx]
        total += torch.sigmoid(betas[idx]) * (A[idx] ** 2)
    return updated, total

In [48]:
def loss_differentiable(
    A: torch.Tensor,
    X: torch.Tensor,
    O: torch.Tensor,
    W: torch.Tensor,
    betas: torch.Tensor,
    sorter,
) -> (torch.Tensor, torch.Tensor):
    """
    Differentiable loss function to measure cost (uses softsort), evaluated
    independently for every individual (row) of A and X.
    :param A: Actions to take for each variable, one row per individual
    :param X: The original values of each feature, one row per individual
    :param O: Ordering scores. Either one row per individual, or a single 1-D
        vector, in which case the same ordering is shared by all individuals
    :param W: The weighted adjacency matrix
    :param betas: The relative mutability of each feature
    :param sorter: The softsort-ing function; called on a 2-D tensor and expected
        to return one (features x features) matrix per row
    :return: (X_bars, cost) where X_bars[n] is X[n] after the actions A[n] have
        been propagated through W plus the identity, and cost[n] is the cost of
        applying A[n] under the (soft) ordering.
    """
    # Number of individuals
    N = A.shape[0]
    n_features = W.shape[0]

    # Build the identity with W's dtype/device so non-default dtypes and
    # non-CPU tensors are neither downcast nor rejected.
    W_temp = W + torch.eye(n_features, dtype=W.dtype, device=W.device)

    S = sorter(O.unsqueeze(0)) if O.dim() == 1 else sorter(O)
    if S.shape[0] == 1 and N > 1:
        # A single ordering was supplied: share it across every individual
        # (a view, no copy) instead of indexing past the end of S.
        S = S.expand(N, -1, -1)

    # Result buffers follow X's dtype and device rather than the global defaults.
    X_bars = torch.zeros_like(X)
    cost = X.new_zeros(X.shape[0])

    # Loop-invariant: mutability weights are the same for every individual.
    mutability = torch.sigmoid(betas)

    # Iterate over each row of A, X, and O
    for n in range(N):
        X_bar = X[n].clone()
        squared_actions = A[n] ** 2

        for i in range(n_features):
            # S[n][i] weights the columns of W_temp, i.e. selects which
            # feature's action is applied at step i of the ordering.
            X_bar += (W_temp * S[n][i]) @ A[n]
            cost[n] += torch.sum(squared_actions * S[n][i] * mutability)

        # Store results for this row
        X_bars[n] = X_bar

    return X_bars, cost

In [None]:
# SETTING UP PARAMETERS TO BE OPTIMIZED
# A: actions (one row per individual), O: ordering scores, C: multiplier on the constraint
A = torch.tensor([[0, 0, 0, 0], [0, 0, 0, 0]], dtype=torch.float32, requires_grad=True)
O = torch.rand(2, 4, dtype=torch.float32, requires_grad=True)
C = torch.rand(1, requires_grad=True, dtype=torch.float32)

# FIXED PARAMETERS
# NOTE(review): beta is listed as fixed but is created with requires_grad=True,
# so gradients are accumulated on it; confirm whether it is meant to be trained.
beta = torch.tensor([1, 1, 1, 1], dtype=torch.float32, requires_grad=True)
X = torch.tensor([[-1, -2, -5, -1], [-3, -6, 3, -6]], dtype=torch.float32)
W_adjacency = torch.tensor(
    [[0, 0, 0, 0], [0.3, 0, 0, 0], [0.2, 0, 0, 0], [0, 0.2, 0.3, 0]],
    dtype=torch.float32,
)
W_classifier = torch.tensor([2, 3, 1, 4], dtype=torch.float32)
sorter = SoftSort(tau=0.1, hard=True)

# Work out initial X_bar and cost.
# The loss takes the feature adjacency matrix as its fourth argument. Passing
# the 1-D classifier weights here would broadcast against the identity without
# raising and silently produce a wrong 4x4 "adjacency".
X_bar, cost = loss_differentiable(A, X, O, W_adjacency, beta, sorter)

In [64]:
def optimise(
    max_optimiser: torch.optim.Optimizer,
    min_optimiser: torch.optim.Optimizer,
    n_epochs: int,
    A: torch.Tensor,
    X: torch.Tensor,
    O: torch.Tensor,
    C: torch.Tensor,
    W_adjacency: torch.Tensor,
    beta: torch.Tensor,
    sorter: SoftSort,
    classifier_margin: float,
    W_classifier: torch.Tensor = None,
):
    """
    Alternating min-max optimisation of cost - C * constraint, where
    constraint = X_bar @ W_classifier - classifier_margin.

    Each epoch first steps ``max_optimiser`` on (C * constraint - cost) and then
    steps ``min_optimiser`` on (cost - C * constraint).
    :param max_optimiser: optimiser stepped in the "maximise wrt C" phase
    :param min_optimiser: optimiser stepped in the "minimise" phase
    :param n_epochs: maximum number of epochs (at least 1)
    :param A: Actions to take for each variable
    :param X: The original values of each feature
    :param O: Ordering scores of the features
    :param C: Multiplier applied to the classifier constraint
    :param W_adjacency: The weighted adjacency matrix
    :param beta: The relative mutability of each feature
    :param sorter: The softsort-ing function
    :param classifier_margin: non-negative margin subtracted from the classifier score
    :param W_classifier: classifier weights; when omitted, the notebook-level
        ``W_classifier`` is used, as this function did originally
    :return: (X_bar, ordering, cost, constraint) from the last epoch that ran;
        X_bar, cost and constraint are the values computed before that epoch's
        final minimisation step.
    """
    assert (
        classifier_margin >= 0
    ), "Classifier margin must be greater than or equal to 0"
    # With zero epochs there would be nothing to return.
    assert n_epochs >= 1, "Number of epochs must be at least 1"

    if W_classifier is None:
        # Backward compatibility: fall back to the notebook-level variable the
        # original implementation read implicitly.
        try:
            W_classifier = globals()["W_classifier"]
        except KeyError:
            raise NameError("name 'W_classifier' is not defined") from None

    def evaluate():
        # One forward pass yields both the cost and the constraint; the
        # constraint is derived from the same X_bar instead of a second call.
        X_bar, cost = loss_differentiable(A, X, O, W_adjacency, beta, sorter)
        constraint = X_bar @ W_classifier - classifier_margin
        return X_bar, cost, constraint

    # Create lists
    objective_list = []
    constraint_list = []

    for i in range(n_epochs):
        # Maximise wrt C
        X_bar, cost, constraint = evaluate()
        # Summing is a no-op for a single individual and makes the loss a
        # scalar when cost/constraint hold one entry per individual.
        max_loss = ((C * constraint) - cost).sum()

        max_optimiser.zero_grad()
        max_loss.backward()
        max_optimiser.step()

        # Minimise wrt A, O, beta (a fresh forward pass is needed because the
        # previous graph was freed by backward()).
        X_bar, cost, constraint = evaluate()
        min_loss = (cost - (C * constraint)).sum()

        min_optimiser.zero_grad()
        min_loss.backward()
        min_optimiser.step()

        # Track objective and constraints
        objective_list.append(cost.sum().item())
        constraint_list.append(constraint.sum().item())

        # Early stopping: both traces flat over the last 10 epochs
        if (
            i > 100
            and np.std(objective_list[-10:]) < 1e-4
            and np.std(constraint_list[-10:]) < 1e-4
        ):
            break

    # Final ordering; a 1-D O is given a batch axis, as the sorter expects 2-D input
    ordering = torch.max(sorter(O.unsqueeze(0) if O.dim() == 1 else O), dim=1)[1]

    # Return results
    return X_bar, ordering, cost, constraint

In [82]:
def optimise(
        max_optimiser: torch.optim.Optimizer,
        min_optimiser: torch.optim.Optimizer,
        n_epochs: int,
        A: torch.Tensor,
        X: torch.Tensor,
        O: torch.Tensor,
        C: torch.Tensor,
        W_adjacency: torch.Tensor,
        W_classifier: torch.Tensor,
        beta: torch.Tensor,
        sorter: SoftSort,
        classifier_margin: float,
):
    """
    Alternating min-max optimisation of cost - C * constraint for a single
    individual, where constraint = X_bar @ W_classifier - classifier_margin.

    Each epoch first steps ``max_optimiser`` on (C * constraint - cost) and then
    steps ``min_optimiser`` on (cost - C * constraint).
    :param max_optimiser: optimiser stepped in the "maximise wrt C" phase
    :param min_optimiser: optimiser stepped in the "minimise" phase
    :param n_epochs: maximum number of epochs (at least 1)
    :param A: Actions to take for each variable
    :param X: The original values of each feature
    :param O: Ordering scores of the features (1-D)
    :param C: Multiplier applied to the classifier constraint
    :param W_adjacency: The weighted adjacency matrix
    :param W_classifier: The classifier weights
    :param beta: The relative mutability of each feature
    :param sorter: The softsort-ing function
    :param classifier_margin: non-negative margin subtracted from the classifier score
    :return: (X_bar, ordering, cost, score) from the last epoch that ran, where
        score is the classifier output X_bar @ W_classifier (the constraint with
        the margin added back).
    """
    assert (
            classifier_margin >= 0
    ), "Classifier margin must be greater than or equal to 0"
    # With zero epochs there would be nothing to return.
    assert n_epochs >= 1, "Number of epochs must be at least 1"

    def evaluate():
        # One forward pass yields both the cost and the constraint; the
        # constraint is derived from the same X_bar instead of a second call.
        X_bar, cost = loss_differentiable(A, X, O, W_adjacency, beta, sorter)
        constraint = X_bar @ W_classifier - classifier_margin
        return X_bar, cost, constraint

    # Create lists
    objective_list = []
    constraint_list = []

    for i in range(n_epochs):
        # Maximise wrt C
        X_bar, cost, constraint = evaluate()
        max_loss = (C * constraint) - cost

        max_optimiser.zero_grad()
        max_loss.backward()
        max_optimiser.step()

        # Minimise wrt A, O, beta (a fresh forward pass is needed because the
        # previous graph was freed by backward()).
        X_bar, cost, constraint = evaluate()
        min_loss = cost - (C * constraint)

        min_optimiser.zero_grad()
        min_loss.backward()
        min_optimiser.step()

        # Track objective and constraints
        objective_list.append(cost.item())
        constraint_list.append(constraint.item())

        # Early stopping: both traces flat over the last 10 epochs
        if (
                i > 100
                and np.std(objective_list[-10:]) < 1e-4
                and np.std(constraint_list[-10:]) < 1e-4
        ):
            break

    # Final ordering (O is 1-D here, so add the batch axis the sorter expects)
    ordering = torch.max(sorter(O.unsqueeze(0)), dim=1)[1]

    # Return results; the margin is added back so the last value is the raw classifier score
    return X_bar, ordering, cost, constraint + classifier_margin

def loss_differentiable(
        A: torch.Tensor,
        X: torch.Tensor,
        O: torch.Tensor,
        W: torch.Tensor,
        betas: torch.Tensor,
        sorter,
) -> (torch.Tensor, torch.Tensor):
    """
    Differentiable loss function to measure cost (uses softsort) for a single
    individual.
    :param A: Actions to take for each variable (1-D)
    :param X: The original values of each feature (1-D, left unmodified)
    :param O: Ordering scores of the features; if 2-D, only the first row's
        ordering is used
    :param W: The weighted adjacency matrix
    :param betas: The relative mutability of each feature
    :param sorter: The softsort-ing function; called on a 2-D tensor and expected
        to return one (features x features) matrix per row
    :return: (X_bar, cost) where X_bar is X after the actions A have been
        propagated through W plus the identity, and cost is the total cost of
        applying A under the (soft) ordering.
    """
    n_features = W.shape[0]

    # Work on a copy so the caller's X is not modified by the in-place updates.
    X_bar = X.clone()
    # Build the identity with W's dtype/device instead of the global defaults.
    W_temp = W + torch.eye(n_features, dtype=W.dtype, device=W.device)

    scores = O.unsqueeze(0) if O.dim() == 1 else O
    S = sorter(scores)[0]

    # Loop-invariant terms of the cost.
    squared_actions = A ** 2
    mutability = torch.sigmoid(betas)

    cost = 0
    for i in range(n_features):
        # S[i] weights the columns of W_temp, i.e. selects which feature's
        # action is applied at step i of the ordering.
        X_bar += (W_temp * S[i]) @ A
        cost += torch.sum(squared_actions * S[i] * mutability)

    return X_bar, cost

# PARAMETERS TO BE OPTIMISED
# One zero-initialised action vector and one random ordering-score vector per
# individual. The random draws happen in the order O, Os, C.
As = [torch.zeros(4, dtype=torch.float32, requires_grad=True) for _ in range(2)]
# NOTE(review): this batched O is not passed to optimise() below (Os is used).
O = torch.rand(2, 4, dtype=torch.float32, requires_grad=True)
Os = [torch.rand(4, dtype=torch.float32, requires_grad=True) for _ in range(2)]
# NOTE(review): C is shared, so the second individual starts from the
# multiplier reached at the end of the first run - confirm this is intended.
C = torch.rand(1, requires_grad=True, dtype=torch.float32)

# FIXED PARAMETERS
beta = torch.tensor([1, 1, 1, 1], dtype=torch.float32, requires_grad=True)
X = torch.tensor([[-1, -2, -5, -1], [-3, -6, 3, -6]], dtype=torch.float32)
W_adjacency = torch.tensor(
    [
        [0, 0, 0, 0],
        [0.3, 0, 0, 0],
        [0.2, 0, 0, 0],
        [0, 0.2, 0.3, 0],
    ],
    dtype=torch.float32,
)
W_classifier = torch.tensor([2, 3, 1, 4], dtype=torch.float32)
sorter = SoftSort(tau=0.1, hard=True)

# Optimise every individual separately, with fresh optimisers each time.
for i in range(len(As)):
    max_optimiser = optim.SGD([C], lr=1e-2)
    max_optimiser.zero_grad()

    min_param_groups = [{"params": [p], "lr": 1e-2} for p in (As[i], Os[i])]
    min_optimiser = optim.SGD(min_param_groups)
    min_optimiser.zero_grad()

    result = optimise(
        max_optimiser=max_optimiser,
        min_optimiser=min_optimiser,
        n_epochs=2_000,
        A=As[i],
        X=X[i],
        O=Os[i],
        C=C,
        W_adjacency=W_adjacency,
        W_classifier=W_classifier,
        beta=beta,
        sorter=sorter,
        classifier_margin=0.01,
    )
    print(result)

(tensor([ 0.1744, -0.2081, -3.9317,  1.0533], grad_fn=<AddBackward0>), tensor([[3, 2, 0, 1]]), tensor(4.7096, grad_fn=<AddBackward0>), tensor(0.0056, grad_fn=<AddBackward0>))
(tensor([ 0.1085, -1.2570,  5.8278, -0.5651], grad_fn=<AddBackward0>), tensor([[0, 1, 3, 2]]), tensor(32.9984, grad_fn=<AddBackward0>), tensor(0.0138, grad_fn=<AddBackward0>))


In [88]:
# Display the batched (2, 4) ordering-score tensor O.
# NOTE(review): the optimisation loop trains the per-individual Os[i]; this O
# is never handed to an optimiser there, so these are its initial random values.
O

tensor([[0.8037, 0.0326, 0.0135, 0.4695],
        [0.7466, 0.7928, 0.5113, 0.3977]], requires_grad=True)

In [96]:
# Largest entry in each column of the matrices the sorter builds from O.
# With a hard sorter every row is one-hot, so a value of 1 in every column
# means each input position was assigned to some rank.
torch.max(sorter(O), dim=1)[0].detach().squeeze()

tensor([[1., 1., 1., 1.],
        [1., 1., 1., 1.]])