Implements Algorithm 3 (ModMod) from "Algorithms for Approximate Minimization of the Difference Between Submodular Functions, with Applications" (Iyer & Bilmes, UAI 2012).

Includes three submodular set functions: linear, budget-additive, and entropy (computed as the log-determinant of a kernel matrix).

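For functions where V and X are Python sets, see the unsuffixed versions (`mod_lower`, `mod_upper`, `make_perm`, `mod_mod`).
For functions where V and X are torch.Tensors (as with entropy), see the versions suffixed with `1` (`mod_lower1`, `mod_upper1`, `make_perm1`, `mod_mod1`).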

In [1]:
import torch
from torch import distributions as dist

import itertools
import pickle

import itertools
import random
import math
import numpy as np
import pandas as pd
import scipy

import matplotlib.pyplot as plt
import seaborn as sns
%matplotlib inline
# Seaborn theme: white background, 'paper' context scaling.
sns.set_style('white')
sns.set_context('paper')
# Plot adjustments: font sizes for tick labels, axis labels, legends and titles.
plt.rcParams.update({'ytick.labelsize': 15})
plt.rcParams.update({'xtick.labelsize': 15})
plt.rcParams.update({'axes.labelsize': 35})
plt.rcParams.update({'legend.fontsize': 30})
plt.rcParams.update({'axes.titlesize': 16})

from gptorch import kernels, models

In [2]:
def mod_lower(X, fn, perm, *args, **kwargs):
    """ Modular lower bound of fn(X) along the permutation chain perm (aka S),
    for any X contained in ground set V.

    Every element of X contributes its marginal gain at the moment it joins
    the chain: fn(perm[:i + 1]) - fn(perm[:i]) for the element sitting at
    position i, and plain fn({perm[0]}) for the element at position 0.
    Contributions are accumulated starting from 0.0.

    Expects X as a set, fn as a Python function, and perm as a list that
    contains every element of X (list.index raises ValueError otherwise). """

    def gain(pos):
        # Marginal value of perm[pos] given the prefix that precedes it.
        if pos == 0:
            return fn({perm[0]}, *args, **kwargs)
        with_elem = fn(set(perm[0:pos + 1]), *args, **kwargs)
        without_elem = fn(set(perm[0:pos]), *args, **kwargs)
        return with_elem - without_elem

    total = 0.0
    for member in X:
        total = total + gain(perm.index(member))
    return total

def mod_lower1(X, fn, perm, *args, **kwargs):
    """ Modular lower bound of fn(X) along the permutation chain perm (aka S),
    for any X contained in ground set V (tensor counterpart of mod_lower).

    Each element (row) of X contributes fn(perm[:i + 1]) - fn(perm[:i]),
    where i is the position of that row in perm; the element at position 0
    contributes fn(perm[:1]). Rows are matched against perm over EVERY
    coordinate, so two different rows that merely share one coordinate
    value are not confused.

    Expects X and perm as torch Tensors whose first dimension indexes the set
    elements (X may be an empty tensor), and fn as a Python function taking
    such a tensor. For an empty X, returns fn(torch.Tensor([])).

    Raises ValueError if some element of X does not occur in perm. """

    if len(X) == 0:
        return fn(torch.Tensor([]), *args, **kwargs)

    low = 0.0  # lower modular bound
    for elem in X:
        hits = perm == elem
        if hits.dim() > 1:
            # A row matches only if all of its coordinates are equal.
            hits = hits.flatten(1).all(dim=1)
        positions = hits.nonzero()
        if positions.numel() == 0:
            raise ValueError("mod_lower1: an element of X does not occur in perm")
        i = int(positions[0, 0])
        if i == 0:
            low = low + fn(perm[0:1], *args, **kwargs)
        else:
            low = low + (fn(perm[0:i + 1], *args, **kwargs) - fn(perm[0:i], *args, **kwargs))

    return low

In [3]:
def mod_upper(X, fn, center, *args, **kwargs):
    """ Modular upper bound of fn(X) for any X contained in ground set V,
    centered at center (the bound equals fn(center) when X == center).

    Computes
        fn(center)
        - sum over j in center, j not in X, of [fn(center) - fn(center without j)]
        + sum over j in X, j not in center, of [fn({j}) - fn(set())].

    Expects X and center as sets, and fn as a Python function that accepts a
    set. Every argument handed to fn is a real set: the empty set is set()
    (not {}, which is a dict) and center-without-j is a set, not a list. """

    f_center = fn(center, *args, **kwargs)  # loop-invariant, evaluate once
    # Out-of-place arithmetic so f_center is never mutated if fn returns a tensor.
    up = f_center  # modular upper bound

    for j in center:
        if j not in X:
            center_noj = {x for x in center if x != j}
            up = up - (f_center - fn(center_noj, *args, **kwargs))

    outside = [j for j in X if j not in center]
    if outside:
        f_empty = fn(set(), *args, **kwargs)
        for j in outside:
            up = up + (fn({j}, *args, **kwargs) - f_empty)

    return up

def mod_upper1(X, fn, center, *args, **kwargs):
    """ Modular upper bound of fn(X) for any X contained in ground set V,
    centered at center (tensor counterpart of mod_upper).

    Computes
        fn(center)
        - sum over rows j of center not in X of [fn(center) - fn(center without j)]
        + sum over rows j of X not in center of [fn(j) - fn(empty)].
    Rows are compared over every coordinate. An empty X is handled like any
    other set: every row of center is then "not in X" and is subtracted.

    Expects X and center as torch Tensors whose first dimension indexes the
    set elements (either may be an empty tensor), and fn as a Python function
    taking such a tensor. fn(empty) is evaluated on X[:0], a zero-length
    slice of X. """

    def row_mask(S, row):
        # Boolean mask over S's first dimension: True where all coordinates equal `row`.
        hits = S == row
        if hits.dim() > 1:
            hits = hits.flatten(1).all(dim=1)
        return hits

    def contains(S, row):
        # An empty tensor contains nothing (and may not broadcast against `row`).
        return len(S) > 0 and bool(row_mask(S, row).any())

    f_center = fn(center, *args, **kwargs)  # loop-invariant, evaluate once
    # Out-of-place arithmetic so f_center is never mutated in place.
    up = f_center  # modular upper bound

    # Subtract the marginal gain of each center element that X lacks.
    for j in center:
        if not contains(X, j):
            center_noj = center[~row_mask(center, j)]
            up = up - (f_center - fn(center_noj, *args, **kwargs))

    # Add the stand-alone gain of each X element outside center.
    outside = [j for j in X if not contains(center, j)]
    if outside:
        f_empty = fn(X[:0], *args, **kwargs)
        for j in outside:
            up = up + (fn(j.unsqueeze(0), *args, **kwargs) - f_empty)

    return up

In [4]:
def make_perm(V, X):
    """ Takes in the ground set V and a set X, and returns a random chain
    permutation of V (as a list) whose first len(X) entries are exactly X.

    The result is uniform over all permutations of V that start with X.
    The elements of X are shuffled among themselves, then followed by an
    independent shuffle of the remaining elements. This gives the same
    distribution as choosing among such permutations, without enumerating
    all len(V)! permutations of V.

    Raises ValueError if X is not a subset of V. """
    ground = list(V)
    head = [v for v in ground if v in X]
    if len(head) != len(X):
        raise ValueError("make_perm: X must be a subset of V")
    tail = [v for v in ground if v not in X]

    random.shuffle(head)  # random order within X
    random.shuffle(tail)  # random order within V \ X
    return head + tail

def make_perm1(V, X):
    """ Takes in the ground set V and a set X, and returns a random chain
    permutation of V whose leading entries are the elements of V that occur
    in X (tensor counterpart of make_perm).

    V and X are torch Tensors whose first dimension indexes the set elements
    (X may be an empty tensor). Elements are compared over every coordinate.
    The matching entries are shuffled and placed first, then the remaining
    entries are shuffled and appended. The permutation is returned as a
    stacked tensor. """
    indices = list(range(len(V)))
    if len(X) == 0:
        random.shuffle(indices)
        return torch.stack([V[i] for i in indices])

    def in_X(row):
        # Whole-row match: all coordinates must agree, not just one.
        hits = X == row
        if hits.dim() > 1:
            hits = hits.flatten(1).all(dim=1)
        return bool(hits.any())

    ind_X = [i for i, v in enumerate(V) if in_X(v)]  # indices of X in V
    chosen = set(ind_X)
    rest = [i for i in indices if i not in chosen]  # rest of indices in V

    random.shuffle(ind_X)  # shuffle indices
    random.shuffle(rest)
    return torch.stack([V[i] for i in ind_X + rest])  # perm from shuffled indices

In [5]:
def mod_mod(V, fn, g):
    """ Implements algorithm3 (ModMod) from paper. Takes in ground set V,
    and functions fn and g, and heuristically minimizes fn(X) - g(X).

    Each iteration does three things:
    - It replaces fn by its modular upper bound centered at the current X
      (mod_upper).
    - It replaces g by its modular lower bound along a random chain that
      starts with X (mod_lower).
    - It takes the minimizer of the resulting modular surrogate. That is
      the set of elements whose surrogate value is strictly below the
      surrogate value at the empty set.
    It stops when the set no longer changes.

    Expects V as a set, and fn and g as submodular Python functions.
    Returns the final set. """
    X = set()
    while True:
        perm = make_perm(V, X)  # choose permutation
        # Objective with the empty set as input; it does not depend on the
        # candidate element, so compute it once per iteration.
        emp = mod_upper(set(), fn, X) - mod_lower(set(), g, perm)
        # Objective with each single element of V as input.
        X_next = {i for i in V
                  if mod_upper({i}, fn, X) - mod_lower({i}, g, perm) < emp}
        if X_next == X:
            return X_next
        X = X_next

def mod_mod1(V, fn, g, *args, **kwargs):
    """ Implements algorithm3 (ModMod) from paper for tensor ground sets.
    Takes in ground set V, and functions fn and g, and heuristically
    minimizes fn(X) - g(X) (tensor counterpart of mod_mod).

    Each iteration selects every element (row) i of V whose surrogate value
    is <= the surrogate value at the empty set. The surrogate is
    mod_upper1(., fn, X) - mod_lower1(., g, perm), with perm a random chain
    that starts with the current X. It stops when the selection no longer
    changes.

    Expects V as a torch Tensor matrix (one element per row), and fn and g as
    submodular Python functions. args[0] is the extra argument passed to fn
    and args[1] the extra argument passed to g (e.g. kernels for entropy).
    kwargs are accepted but not used. Returns the selected rows stacked into
    a tensor, or an empty tensor if none are selected. """

    empty = torch.Tensor([])
    X = torch.Tensor([])

    while True:
        perm = make_perm1(V, X)  # choose permutation
        # Objective with the empty set as input; independent of the candidate.
        emp = mod_upper1(empty, fn, X, args[0]) - mod_lower1(empty, g, perm, args[1])
        chosen = []
        for i in V:
            single = torch.stack([i])
            obj = mod_upper1(single, fn, X, args[0]) - mod_lower1(single, g, perm, args[1])
            if obj <= emp:
                chosen.append(i)
        # Keep EVERY selected row (previously only the first and latest survived).
        X_next = torch.stack(chosen) if chosen else torch.Tensor([])
        if X.size() == X_next.size() and torch.eq(X, X_next).all():
            return X_next
        X = X_next

In [6]:
# two examples of submodular set functions

def linear(X):
    """ Linear (modular) set function: f(X) = sum over i in X of w_i.

    Each element of X serves as its own weight w_i, so f(X) is simply the
    running total of the elements, added one by one in iteration order.
    With non-negative elements, f is also monotone. An empty X yields 0. """

    if not len(X):  # empty set
        return 0

    total = 0
    for weight in X:
        total = total + weight
    return total
    
def budget_add(X, B=10):
    """ Budget-additive set function: f(X) = min(sum over i in X of w_i, B).

    A small generalization of the linear case in which each element of X
    is its own weight. For w_i >= 0 and B >= 0 it is monotone submodular.
    An empty X yields 0 regardless of B. """

    if not len(X):  # empty set
        return 0

    total = 0
    for weight in X:
        total = total + weight
    return min(total, B)

In [7]:
# entropy function

def entropy(X, k):
    """ Takes in a set of points X (a tensor; each row is an element, as in
    the notebook's examples) and a kernel/covariance function k, and returns
    log det k(X, X).

    This is the log-determinant surrogate for the entropy of a multivariate
    normal with covariance k(X, X). The exact differential entropy is
    0.5 * (len(X) * log(2*pi*e) + log det k(X, X)), which differs from this
    value only by a positive factor and a term that is modular in X.
    An empty X returns 0. """
    if len(X) == 0:
        return torch.Tensor([0.0])[0]

    cov = k(X, X)
    # torch.logdet works in log space, so the result stays finite when
    # det(cov) itself would underflow to 0 (e.g. many nearly-correlated
    # points), where torch.log(torch.det(cov)) would give -inf.
    return torch.logdet(cov)

In [None]:
# Small ground set: 10 random 4-dimensional points, plus two gptorch kernels.
# NOTE(review): V, ke and ke2 are reassigned by the following cell.
V = torch.randn(10, 4)
ke = kernels.SEKernel()
ke2 = kernels.MaternKernel()

In [8]:
# Ground set: 100 random points in R^4, one element per row, plus two gptorch kernels.
V = torch.randn(100, 4)
ke = kernels.SEKernel()
ke2 = kernels.MaternKernel()

# each row of X is an element of the set of X
# X and center are two independently drawn 5-element subsets of V
# (5 distinct row indices each; the two subsets may overlap).
inds = np.random.choice(100, size=5, replace=False)
inds = torch.LongTensor(inds)
X = V[inds]
inds = np.random.choice(100, size=5, replace=False)
inds = torch.LongTensor(inds)
center = V[inds]

In [None]:
# Log-determinant "entropy" of X under the gptorch SEKernel.
entropy(X, ke)

In [None]:
# Same quantity under the gptorch MaternKernel.
entropy(X, ke2)

In [None]:
# Random chain over V whose prefix is `center`, then the modular lower
# bound of entropy(., ke2) evaluated at X along that chain.
perm = make_perm1(V, center)
mod_lower1(X, entropy, perm, ke2)

In [9]:
# Modular upper bound of entropy(., ke) at X, centered at `center`.
mod_upper1(X, entropy, center, ke)

tensor(0.2952)

In [10]:
# Sum of the upper bound evaluated on each single row of X (compare with the value above).
print(torch.stack([mod_upper1(x.unsqueeze(0), entropy, center, ke) for x in X]).sum()) # upper bound

tensor(1.4762)


In [None]:
# check modular: sum the single-row upper and lower bounds over the rows of X
# (uses `perm` from the lower-bound cell above).
print(torch.stack([mod_upper1(x.unsqueeze(0), entropy, center, ke) for x in X]).sum()) # upper bound
print(torch.stack([mod_lower1(x.unsqueeze(0), entropy, perm, ke2) for x in X]).sum()) # lower bound

In [None]:
# Run ModMod to approximately minimize entropy(., ke) - entropy(., ke2) over
# subsets of the rows of V (ke is passed to fn, ke2 to g).
X = mod_mod1(V, entropy, entropy, ke, ke2)
X

In [None]:
# tests for mod_lower
# For a linear fn the modular bounds are exact, so they equal the plain sum of X.
assert mod_lower({2, 4, 5}, linear, [3, 1, 2, 5, 4]) == 11 # V = {1...5}
assert mod_lower({1, 7, 9}, linear, [3, 1, 4, 2, 10, 6, 9, 7, 5, 8]) == 17 # V = {1...10}

# budget_add (B=10) along this chain: element 2 adds 6-4=2, element 5 adds 10-6=4
# (the cap is reached), element 4 adds 10-10=0, giving 6.
assert mod_lower({2, 4, 5}, budget_add, [3, 1, 2, 5, 4]) == 6 # V = {1...5}

# tests for mod_upper
assert mod_upper({2, 4, 5}, linear, {1, 2, 5}) == 11 # V = {1...5}
assert mod_upper({1, 7, 9}, linear, {10, 6, 9, 7}) == 17 # V = {1...10}

# tests for make_perm: the first len(X) entries of the chain must be X
assert set(make_perm({1, 2, 3, 4, 5}, {4, 3, 5})[0:3]).issubset({4, 3, 5})
assert set(make_perm({1, 2, 3, 4, 5, 6, 7, 8, 9}, {9, 3, 7})[0:3]).issubset({9, 3, 7})

In [None]:
# Display the value asserted in the tests above (expected 17).
mod_upper({1, 7, 9}, linear, {10, 6, 9, 7})


In [None]:
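# NOTE: unused alternative modular upper bound, kept for reference. For j in
# center \ X it subtracts fn(j | V \ {j}) = fn(V) - fn(V \ {j}); for j in
# X \ center it adds fn(j | center) = fn(center + [j]) - fn(center).
# def mod_upper(X, V, fn, center, *args, **kwargs):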
#     """ Modular upper bound of fn(X) for any X contained in ground set V, 
#     centered at center.
    
#     Expects X and center as lists of tuples, and fn as a Python function. """
    
#     up = fn(center, *args, **kwargs) # modular upper bound
    
#     for j in center:
#         if j not in X:
#             V_noj = [x for x in V if x != j]
#             up -= fn(V, *args, **kwargs) - fn(V_noj, *args, **kwargs)

#     for j in X:
#         if j not in center:
#             up += fn(center + [j], *args, **kwargs) - fn(center, *args, **kwargs)
            
#     return up