Implements Algorithm 3 (ModMod) from "Algorithms for Approximate Minimization of the Difference Between Submodular Functions, with Applications" (Iyer & Bilmes, 2012).

In [1]:
import torch
import itertools
import pickle

import itertools
import random

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
%matplotlib inline
# Seaborn theme first, then per-element font sizes layered on top of it.
sns.set_style('white')
sns.set_context('paper')
# Plot adjustments (applied in this order): tick labels, axis labels, legend, title.
plt.rcParams.update({
    'ytick.labelsize': 15,
    'xtick.labelsize': 15,
    'axes.labelsize': 35,
    'legend.fontsize': 30,
    'axes.titlesize': 16,
})

In [2]:
def mod_lower(X, fn, perm):
    """ Modular lower bound of fn at X, built from the permutation chain perm
    (aka S), for any X contained in the ground set V.

    For perm = (p_1, ..., p_n) with prefixes S_i = {p_1, ..., p_i} and
    S_0 = set(), element p_i gets the weight fn(S_i) - fn(S_{i-1}) (its
    marginal gain along the chain, including p_1 relative to the empty set),
    and the bound is the sum of the weights of the elements of X.

    The weights telescope, so the value at every prefix S_i equals
    fn(S_i) - fn(set()), and the value at the empty set is 0. For submodular
    fn this is a lower bound of fn(X) - fn(set()); add fn(set()) to bound
    fn(X) itself. A constant offset does not change which set minimizes it.

    Expects X as a set (or other iterable of elements), fn as a Python set
    function, and perm as a list (or tuple) containing every element of X.
    Returns a float. Raises ValueError if an element of X is not in perm. """

    low = 0.0  # lower modular bound

    for elem in X:
        i = perm.index(elem)  # position of elem in the chain
        # Uniform marginal-gain formula; for i == 0 the previous prefix is
        # the empty set, so fn(set()) is subtracted as well.
        low += fn(set(perm[:i + 1])) - fn(set(perm[:i]))

    return low

In [3]:
def mod_upper(X, fn, center):
    """ Modular upper bound of fn(X) for any X contained in ground set V,
    centered at center (C):

        m(X) = fn(C)
               - sum over j in C but not in X of [fn(C) - fn(C - {j})]
               + sum over j in X but not in C of [fn({j}) - fn(set())]

    m(C) == fn(C) (the bound is tight at its center).

    Expects X and center as sets (any iterable is accepted and copied), and
    fn as a Python set function; fn is only ever called with sets. """

    X = set(X)
    center = set(center)
    f_center = fn(center)  # evaluated once instead of once per removed element

    up = f_center  # modular upper bound

    for j in center - X:
        # remove j's marginal gain with respect to the rest of the center
        up -= f_center - fn(center - {j})

    added = X - center
    if added:
        f_empty = fn(set())  # a real empty set, not the dict literal {}
        for j in added:
            # add j's marginal gain with respect to the empty set
            up += fn({j}) - f_empty

    return up



In [8]:
def make_perm(V, X):
    """ Takes in the ground set V and a set X, and returns a random chain
    permutation (as a list) that contains X, i.e. whose first len(X)
    elements are exactly the elements of X, followed by the rest of V.

    The result is uniform over all such permutations: the elements of X and
    the remaining elements of V are shuffled independently. This avoids
    enumerating all len(V)! permutations of V.

    Raises ValueError if X is not a subset of V. """
    X = set(X)
    if not X.issubset(V):
        raise ValueError("X must be a subset of the ground set V")
    head = list(X)                          # chain prefix: exactly the elements of X
    tail = [v for v in V if v not in X]     # everything else in V
    random.shuffle(head)
    random.shuffle(tail)
    return head + tail

def algo3(V, fn, g, max_iter=None):
    """ Implements algorithm 3 (ModMod) from the paper: approximately
    minimizes fn(X) - g(X) over subsets X of the ground set V.

    Starting from X = set(), each round
      1. draws a random permutation chain of V whose first len(X) elements
         are exactly X (make_perm),
      2. replaces fn by its modular upper bound centered at X (mod_upper) and
         g by its modular lower bound along that chain (mod_lower),
      3. sets X to the minimizer of the resulting modular surrogate: element
         i is included iff the surrogate at {i} is strictly below the
         surrogate at the empty set.
    The loop stops when a round leaves X unchanged, or after max_iter rounds.

    Expects V as a set, and fn and g as submodular Python functions.
    max_iter: optional cap on the number of rounds; None (default) iterates
    until a fixed point is reached.
    Returns the final set X. """
    X = set()
    rounds = 0
    while max_iter is None or rounds < max_iter:
        rounds += 1
        perm = make_perm(V, X)  # choose permutation whose prefix is X
        # objective with the empty set as input (same for every i this round)
        emp = mod_upper(set(), fn, X) - mod_lower(set(), g, perm)
        # Build a fresh set each round so X_next never aliases X and elements
        # can also leave the solution.
        X_next = {i for i in V
                  if mod_upper({i}, fn, X) - mod_lower({i}, g, perm) < emp}
        if X_next == X:
            break
        X = X_next

    return X

In [11]:
# two examples of submodular set functions

def linear(X):
    """ Linear-style set function: f(X) = sum over i in X of w_i, where the
    weight of each element is the element itself (w_i = i). With
    non-negative elements the function is monotone.

    NOTE(review): the empty set returns -1 rather than the linear value 0,
    so this is not exactly linear/modular. The mod_upper asserts in the test
    cell (expected 12 and 18) depend on this -1. """
    return sum(X) if len(X) else -1
    
def budget_add(X, B=10):
    """ Budget-additive set function f(A) = min{sum over i in A of w_i, B},
    with each element serving as its own weight (w_i = i). For w_i >= 0 and
    B >= 0 it is monotone submodular.

    NOTE(review): the empty set returns -10 (not min(0, B) = 0), and that
    value does not follow B; confirm this is intended. """
    total = sum(X)
    return min(total, B) if len(X) > 0 else -10

In [None]:
# Sanity checks for the modular bounds and the permutation sampler.
# Each bound case is (X, fn, perm-or-center, expected value).
_lower_cases = (
    ({2, 4, 5}, linear, [3, 1, 2, 5, 4], 11),  # V = {1...5}
    ({1, 7, 9}, linear, [3, 1, 4, 2, 10, 6, 9, 7, 5, 8], 17),  # V = {1...10}
    ({2, 4, 5}, budget_add, [3, 1, 2, 5, 4], 6),  # V = {1...5}
)
for _X, _fn, _perm, _want in _lower_cases:
    assert mod_lower(_X, _fn, _perm) == _want

_upper_cases = (
    ({2, 4, 5}, linear, {1, 2, 5}, 12),  # V = {1...5}
    ({1, 7, 9}, linear, {10, 6, 9, 7}, 18),  # V = {1...10}
)
for _X, _fn, _center, _want in _upper_cases:
    assert mod_upper(_X, _fn, _center) == _want

# make_perm must place the elements of X at the front of the permutation.
for _V, _X in (({1, 2, 3, 4, 5}, {4, 3, 5}),
               ({1, 2, 3, 4, 5, 6, 7, 8, 9}, {9, 3, 7})):
    assert set(make_perm(_V, _X)[:len(_X)]) <= _X