In [1]:
import sys
import numpy as np
import torch
from torch.autograd import gradcheck

import matplotlib
import pandas as pd

import mnp

In [2]:
# Numerical gradient checks (torch.autograd.gradcheck) for the two custom
# autograd functions exposed by `mnp`: DeepSubmodular and LogQ.
n = 5
layer_sizes = np.full(3, 3, dtype=np.int64)  # length-3 int64 array, every entry 3

# --- DeepSubmodular -------------------------------------------------------
# Random inputs for the DSF; only s_input is differentiated here.
s_input, m_input = mnp.gen_deep_submodular_bernoulli(n, layer_sizes, p=0.2)
s_input += 0.5  # Needed for layers - eps to be positive
s_input = torch.tensor(s_input, requires_grad=True)
m_input = torch.tensor(m_input, requires_grad=False)
# NOTE(review): atol=5e-1 is a very loose tolerance and eps=1e-8 a very small
# finite-difference step -- confirm both are intentional for this function.
test = gradcheck(
    mnp.DeepSubmodular.apply,
    (s_input, m_input, n, layer_sizes),
    eps=1e-8,
    atol=5e-1,
)
print(test)

# --- LogQ loss ------------------------------------------------------------
# Random double-precision prediction (differentiated) and target (constant).
y = torch.rand(n, dtype=torch.float64, requires_grad=True)
y_gt = torch.rand(n, dtype=torch.float64)
test = gradcheck(
    mnp.LogQ.apply,
    (y, y_gt),
    eps=1e-6,
    atol=1e-4,
)
print(test)

True
True


In [3]:
# Build the target: generate a random submodular set function that the
# training cell will try to fit.
n = 100
layer_sizes = np.full(3, 10, dtype=np.int64)  # length-3 int64 array, every entry 10

sub_weights_gt, mod_weights_gt = mnp.gen_deep_submodular_bernoulli(
    n, layer_sizes, p=0.2
)
FA_gt, yprime_gt = mnp.mnp_deep_contig_w(
    n, layer_sizes, sub_weights_gt, mod_weights_gt
)

# Number of non-positive entries of yprime_gt; reported below as |A*|.
card = np.count_nonzero(yprime_gt <= 0.0)
print('Ground Truth F_A*:\t', FA_gt, sep='')
print('Ground Truth |A*|:\t', card, sep='')

# Rebind as a torch tensor: the loss in the training cell receives it as target.
yprime_gt = torch.tensor(yprime_gt)


Ground Truth F_A*:	-15.596981707785133
Ground Truth |A*|:	48


In [None]:
# Initialize weights/variables
s_weights, m_weights = mnp.gen_deep_submodular_bernoulli(n, layer_sizes, p=0.2)
s_weights = torch.tensor(s_weights, requires_grad=True)
m_weights = torch.tensor(m_weights, requires_grad=True)

# Create Model
deep_mnp = mnp.DeepSubmodular.apply
loss_fn = mnp.LogQ.apply

# Use SGD
optimizer = torch.optim.SGD([s_weights, m_weights], lr=5e-1)

losses = []
# Best-so-far bookkeeping. `None` means "nothing recorded yet", so no
# iteration-number special case is needed inside the loop.
best_loss = None
best_s_weights = None
best_m_weights = None
for iter_num in range(500):
    optimizer.zero_grad()

    yprime = deep_mnp(s_weights, m_weights, n, layer_sizes)
    loss = loss_fn(yprime, yprime_gt)
    loss_value = loss.item()

    # Snapshot BEFORE the optimizer step: `loss` was computed from the current
    # weights, and step()/clamp_() below modify them in place. Detached clones
    # are required -- a plain assignment would only alias the live parameters,
    # so the "best" weights would silently track the latest ones.
    if best_loss is None or loss_value < best_loss.item():
        best_loss = loss.detach().clone()
        best_s_weights = s_weights.detach().clone()
        best_m_weights = m_weights.detach().clone()

    # The graph is rebuilt by the forward pass every iteration, so it does not
    # need to be retained after backward().
    loss.backward()
    optimizer.step()
    # Project s_weights back onto the non-negative orthant without recording
    # the in-place clamp in autograd.
    with torch.no_grad():
        s_weights.clamp_(min=0.0)

    if iter_num % 50 == 0:
        print('iteration ' + str(iter_num) + ' loss ' + str(loss_value))
    losses.append(loss_value)

# Loss curve over iterations.
df = pd.DataFrame(losses, columns=('losses',))
ax = df.plot()
ax.set(xlabel='iteration', ylabel='loss');
    

iteration 0 loss 72.14332281904609
iteration 50 loss 4.148435375993554
iteration 100 loss 2.047433758092062
iteration 150 loss 1.3552410006078104
iteration 200 loss 1.0119333695221022
iteration 250 loss 0.8070927585566823
iteration 300 loss 0.6710862373006501
iteration 350 loss 0.5742401656723194
