In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import make_moons
from sklearn.model_selection import train_test_split
from tqdm import tqdm
from maraboupy import Marabou, MarabouCore
from prettytable import PrettyTable
import sys
from maraboupy.MarabouNetworkONNX import MarabouNetworkONNX


In [None]:
#simple MLP as our model
class MLP(nn.Module):
    """Fully connected network with ReLU activations between linear layers.

    ``arch`` lists the layer widths from input to output, e.g. ``[2, 32, 2]``
    builds ``Linear(2, 32) -> ReLU -> Linear(32, 2)``. No activation follows
    the final linear layer, so the network outputs raw logits.
    """

    def __init__(self, arch):
        super().__init__()
        final_idx = len(arch) - 2  # index of the last Linear layer
        modules = []
        for idx, (fan_in, fan_out) in enumerate(zip(arch[:-1], arch[1:])):
            modules.append(nn.Linear(fan_in, fan_out))
            # Every linear layer except the last one is followed by a ReLU.
            if idx != final_idx:
                modules.append(nn.ReLU())

        self.layers = nn.Sequential(*modules)

    def forward(self, x):
        """Apply the stacked layers to ``x`` and return the result."""
        out = self.layers(x)
        return out

#subroutine to train model wholesale using Outlier Exposure with many lambda values
def train(X_train, y_train, lams, n_epochs=100, batch_size=128, outlier_data=None):
    """Train one MLP per Outlier-Exposure weight and return them keyed by weight.

    For every ``lam`` in ``lams`` a fresh ``MLP([2, 32, 32, 32, 2])`` is trained
    with Adam (lr=0.01) on mini-batches taken in order from ``X_train``. The loss
    is cross-entropy on the batch plus, when ``lam > 0``,
    ``lam * mean(sum(logits(outliers) ** 2, dim=1))`` evaluated on the full
    outlier set at every step.

    Args:
        X_train: training inputs, indexable along the first axis.
        y_train: integer class labels aligned with ``X_train``.
        lams: iterable of penalty weights; each becomes a key of the result.
        n_epochs: number of passes over the training data per model.
        batch_size: number of samples per optimisation step.
        outlier_data: outlier inputs used for the penalty. When ``None`` the
            notebook-level ``outliers`` array is used, which is what this
            function always read before the parameter existed.

    Returns:
        dict mapping each value of ``lams`` to its trained ``MLP``.
    """
    lams = list(lams)

    # Convert once up front instead of rebuilding tensors for every batch.
    features = torch.tensor(X_train, dtype=torch.float32)
    labels = torch.tensor(y_train, dtype=torch.long)

    outlier_tensor = None
    if any(lam > 0 for lam in lams):
        # Only touch the notebook global when a penalty is actually requested,
        # so a pure lam == 0 sweep still works without any outliers defined.
        if outlier_data is None:
            outlier_data = outliers
        outlier_tensor = torch.tensor(outlier_data, dtype=torch.float32)

    criterion = nn.CrossEntropyLoss()
    models = {}
    for lam in lams:
        model = MLP([2, 32, 32, 32, 2])
        optimizer = optim.Adam(model.parameters(), lr=0.01)

        for _ in range(n_epochs):
            for start in range(0, len(features), batch_size):
                batch_X = features[start:start + batch_size]
                batch_y = labels[start:start + batch_size]

                optimizer.zero_grad()
                loss = criterion(model(batch_X), batch_y)
                if lam > 0:
                    # Push outlier logits towards zero, i.e. towards a uniform softmax.
                    outlier_logits = model(outlier_tensor)
                    loss = loss + lam * torch.mean(torch.sum(outlier_logits ** 2, dim=1))
                loss.backward()
                optimizer.step()
        models[lam] = model
    return models

In [None]:
def plot_conf(models, X, y, lams, zoom, fname=None):
    """Plot each model's maximum-softmax confidence over a grid around the data.

    One panel is drawn per entry of ``lams`` (at most four per row). Each panel
    shows filled contours of the highest softmax probability on a 400x400 grid
    spanning the range of ``X`` extended by ``zoom`` on every side, with the
    points of ``X`` overlaid and coloured by ``y``.

    Args:
        models: mapping from lambda value to a trained model.
        X: array of 2-D points; column 0 is plotted on x, column 1 on y.
        y: per-point values used to colour the scatter.
        lams: lambda values to plot; each must be a key of ``models``.
        zoom: margin added around the data range on both axes.
        fname: if given, the figure is also saved as ``{fname}.pdf``.

    Raises:
        ValueError: if ``lams`` is empty (no panel can be created).
    """
    n_panels = len(lams)
    if n_panels == 0:
        raise ValueError("lams must contain at least one value to plot")

    rows = (n_panels + 3) // 4
    cols = min(n_panels, 4)
    fig, axs = plt.subplots(nrows=rows, ncols=cols, figsize=(5 * cols, 4 * rows), squeeze=False)

    # The evaluation grid depends only on X and zoom, so build it once for all models.
    x_span = np.linspace(X[:, 0].min() - zoom, X[:, 0].max() + zoom, 400)
    y_span = np.linspace(X[:, 1].min() - zoom, X[:, 1].max() + zoom, 400)
    xx, yy = np.meshgrid(x_span, y_span)
    grid = torch.tensor(np.c_[xx.ravel(), yy.ravel()], dtype=torch.float32)
    levels = np.linspace(0, 1, 50)
    ticks = np.linspace(0, 1, 11)

    for i, lam in enumerate(lams):
        ax = axs[i // cols, i % cols]

        with torch.no_grad():
            logits = models[lam](grid)
        conf = torch.softmax(logits, dim=1).numpy().max(axis=1).reshape(xx.shape)

        conf_plot = ax.contourf(xx, yy, conf, alpha=0.7, levels=levels, cmap='viridis')
        ax.scatter(X[:, 0], X[:, 1], c=y, cmap=plt.cm.bwr, alpha=0.7)
        # Raw f-string: '\l' is not a valid escape sequence in a normal string literal.
        ax.set_title(rf'$\lambda = {lam}$', fontsize=14)

        fig.colorbar(conf_plot, ax=ax, ticks=ticks)

    # Hide the unused panels of the last row.
    for j in range(n_panels, rows * cols):
        axs[j // cols, j % cols].axis('off')

    fig.tight_layout()

    if fname is not None:
        fig.savefig(f'{fname}.pdf', format='pdf')

    plt.show()

In [None]:
# Two-moons dataset with a fixed split so every run sees the same data.
X, y = make_moons(n_samples=1000, noise=0.06, random_state=1)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=1)

n_outliers = 100
c_mus = [np.mean(X[y == k], axis=0) for k in range(2)]  # per-class mean
cov = np.cov(X.T)                                       # covariance of all points
test_cov = cov * 2.0                                    # wider proposal for the test outliers
eps = 0.001                                             # density threshold below which a draw counts as an outlier

# Dedicated seeded generator: the outlier draws were previously unseeded, which
# made the notebook non-reproducible even though the moons and split were seeded.
outlier_rng = np.random.default_rng(1)


def _sample_low_likelihood(mean, covariance, n_points, threshold, rng):
    """Rejection-sample ``n_points`` draws from N(mean, covariance) whose
    bivariate Gaussian density under that same distribution is below ``threshold``."""
    # Both quantities are loop-invariant, so compute them once per call.
    precision = np.linalg.inv(covariance)
    norm_const = 2 * np.pi * np.sqrt(np.linalg.det(covariance))

    kept = []
    while len(kept) < n_points:
        candidate = rng.multivariate_normal(mean, covariance)
        diff = candidate - mean
        likelihood = np.exp(-0.5 * np.dot(np.dot(diff, precision), diff)) / norm_const
        if likelihood < threshold:
            kept.append(candidate)
    return kept


outliers = []
test_outliers = []

for k in range(2):
    # Half of the outliers are drawn around each class mean.
    outliers.extend(_sample_low_likelihood(c_mus[k], cov, n_outliers // 2, eps, outlier_rng))
    test_outliers.extend(_sample_low_likelihood(c_mus[k], test_cov, n_outliers // 2, eps, outlier_rng))

outliers = np.array(outliers)
test_outliers = np.array(test_outliers)

In [None]:
# Side-by-side view of the training data without and with the sampled outliers.
fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(10, 4))

axs[0].scatter(X_train[:, 0], X_train[:, 1], c=y_train, cmap=plt.cm.bwr, alpha=0.7)
axs[0].set_title("Original")

axs[1].scatter(X_train[:, 0], X_train[:, 1], c=y_train, cmap=plt.cm.bwr, alpha=0.7, label="Training data")
axs[1].scatter(outliers[:, 0], outliers[:, 1], color='green', marker='x', alpha=0.7, label="Training outliers")
axs[1].scatter(test_outliers[:, 0], test_outliers[:, 1], color='orange', marker='x', alpha=0.7, label="Test outliers")
axs[1].set_title("With outliers")

axs[1].legend(loc="upper right")

# Apply the layout before saving so the PDF matches what is displayed.
fig.tight_layout()
fig.savefig('OE.pdf', format='pdf')
plt.show()

In [None]:
# Outlier-Exposure weights to sweep; lam == 0 trains without the outlier penalty.
lams = [0, 0.1, 0.5, 1, 5, 10, 20, 50]

# One trained model per weight, keyed by the weight itself.
models = train(X_train, y_train, lams=lams)

In [None]:
# Confidence maps for every trained model, viewed far beyond the data range; saved as conff1.pdf.
plot_conf(
    models,
    X_train,
    y_train,
    lams,
    zoom=30,
    fname="conff1",
)

In [None]:
def verify_ood_rejection(model, X_OOD, tau, f):
    """Fraction of OOD samples whose maximum softmax confidence is at most ``tau``.

    The model is exported to ``model.onnx`` once and loaded into Marabou for
    every sample. Each input variable is pinned to the sample's value (lower
    bound == upper bound), the query is solved, and the softmax of the output
    assignment is compared with ``tau``.

    Previously the sample was never encoded in the query (no input bounds were
    set), the output constraints that were built were never handed to the
    solver, and the output values were read with an index expression that
    fails on the flattened variable array. The unused constraints are dropped.

    Args:
        model: trained torch module to verify.
        X_OOD: array-like of OOD inputs, one sample per row.
        tau: confidence threshold; a sample counts as rejected when its
            maximum softmax probability is <= ``tau``.
        f: open file handle supplied by the caller; currently not written to,
            kept for interface compatibility.

    Returns:
        Mean of the per-sample boolean outcomes; NaN when ``X_OOD`` is empty.
    """
    samples = np.asarray(X_OOD, dtype=np.float64)
    if samples.shape[0] == 0:
        return float("nan")

    # The network is identical for every sample, so export it only once.
    onnx_f = "model.onnx"
    dummy_input = torch.tensor(samples[:1], dtype=torch.float32)
    torch.onnx.export(model, dummy_input, onnx_f, input_names=['input'], output_names=['output'])
    options = Marabou.createOptions(verbosity=0)

    results = []
    for x in samples:
        network = Marabou.read_onnx(onnx_f)
        in_vars = network.inputVars[0].flatten()
        out_vars = network.outputVars[0].flatten()

        # Pin the input to this sample so the query evaluates the network at x.
        for var, value in zip(in_vars, x.ravel()):
            network.setLowerBound(int(var), float(value))
            network.setUpperBound(int(var), float(value))

        exit_code, vals, stats = network.solve(options=options)

        if exit_code == "sat":
            # vals maps variable index -> assigned value; gather the outputs in order.
            logits = torch.tensor([vals[int(v)] for v in out_vars], dtype=torch.float32)
            confidence = torch.softmax(logits, dim=0).max().item()
            results.append(confidence <= tau)
        else:
            # Kept from the original: any non-"sat" outcome counts as rejected.
            results.append(True)

    return np.mean(results)

def verify_robustness(model, X, epsilon, tau, f):
    """Fraction of samples counted as robust inside an L-infinity box of radius ``epsilon``.

    For every sample the model (exported once to ``model.onnx``) is loaded into
    Marabou, each input variable is bounded to ``[x_i - epsilon, x_i + epsilon]``
    and, for every output ``i``, the constraint
    ``sum_{j != i} out_j - out_i >= 0`` is added. With two outputs these
    constraints together force both logits to be equal. A sample counts as
    robust when the query is not "sat", or when the satisfying assignment
    still has maximum softmax probability >= ``tau``.

    Previously the bounds were placed on a separate ``InputQuery`` and the
    constraints collected in a list, then both were passed to
    ``network.solve`` in the positions of its ``filename`` and ``verbose``
    parameters, so the solver never saw them. Only the first input dimension
    was bounded, and the output values were read with an index expression
    that fails on the flattened variable array.

    Args:
        model: trained torch module to verify.
        X: array-like of inputs, one sample per row.
        epsilon: half-width of the input box around each sample.
        tau: confidence threshold applied to a satisfying assignment.
        f: open file handle supplied by the caller; currently not written to,
            kept for interface compatibility.

    Returns:
        Mean of the per-sample boolean outcomes; NaN when ``X`` is empty.
    """
    samples = np.asarray(X, dtype=np.float64)
    if samples.shape[0] == 0:
        return float("nan")

    # The network is identical for every sample, so export it only once.
    onnx_f = "model.onnx"
    dummy_input = torch.tensor(samples[:1], dtype=torch.float32)
    torch.onnx.export(model, dummy_input, onnx_f, input_names=['input'], output_names=['output'])
    options = Marabou.createOptions(verbosity=0)

    results = []
    for x in samples:
        # Fresh network per sample so constraints never accumulate across samples.
        network = Marabou.read_onnx(onnx_f)
        in_vars = network.inputVars[0].flatten()
        out_vars = [int(v) for v in network.outputVars[0].flatten()]

        # Box constraint on every input dimension.
        for var, value in zip(in_vars, x.ravel()):
            network.setLowerBound(int(var), float(value) - epsilon)
            network.setUpperBound(int(var), float(value) + epsilon)

        # addInequality encodes sum(coeff * var) <= scalar, so the original
        # "sum_{j != i} out_j - out_i >= 0" becomes "out_i - sum_{j != i} out_j <= 0".
        n_out = len(out_vars)
        for i in range(n_out):
            coeffs = [1.0 if j == i else -1.0 for j in range(n_out)]
            network.addInequality(out_vars, coeffs, 0.0)

        exit_code, vals, stats = network.solve(options=options)

        if exit_code == "sat":
            # vals maps variable index -> assigned value; gather the outputs in order.
            logits = torch.tensor([vals[v] for v in out_vars], dtype=torch.float32)
            confidence = torch.softmax(logits, dim=0).max().item()
            results.append(confidence >= tau)
        else:
            # Kept from the original: any non-"sat" outcome counts as robust.
            results.append(True)

    return np.mean(results)

def verify_monotonicity(model, X1, X2, X_ID, f):
    """Return the fraction of (x1, x2) pairs recorded as passing the check below.

    For each pair, d1 and d2 are the Euclidean distances from x1 and x2 to
    their nearest row of ``X_ID``. Pairs with ``d1 <= d2`` are recorded as
    passing without running the solver. For the remaining pairs the model is
    exported to ``model.onnx``, loaded into Marabou and solved; the pair passes
    only if the solver reports "unsat".

    NOTE(review): as written the Marabou query carries no information about
    the pair -- see the inline notes. The intended property (presumably a
    confidence-vs-distance ordering -- TODO confirm) is not actually encoded.

    Args:
        model: torch module to export and check.
        X1, X2: equally indexed collections of inputs, paired element-wise by zip.
        X_ID: 2-D array of reference points used for the distance computation.
        f: file handle supplied by the caller; not used in this function.

    Returns:
        ``np.mean`` of the per-pair boolean results.
    """
    results = []
    for x1, x2 in zip(X1, X2):
        # Distance of each point to its nearest reference point.
        d1 = np.min(np.linalg.norm(X_ID - x1, axis=1))
        d2 = np.min(np.linalg.norm(X_ID - x2, axis=1))

        # Pairs where x1 is at least as close as x2 pass without a solver call.
        if d1 <= d2:
            results.append(True)
            continue

        xin1 = torch.tensor(x1, dtype=torch.float32).unsqueeze(0)
        # NOTE(review): xin2 is never used after this line.
        xin2 = torch.tensor(x2, dtype=torch.float32).unsqueeze(0)

        # xin1 is only the example input handed to the ONNX exporter.
        onnx_f = "model.onnx"
        torch.onnx.export(model, xin1, onnx_f, input_names=['input'], output_names=['output'])
        network = Marabou.read_onnx(onnx_f)

        # NOTE(review): both names refer to the same output variables of the
        # same network, so every addend pair below cancels (v - v).
        out_vars1 = network.outputVars[0].flatten()
        out_vars2 = network.outputVars[0].flatten()

        # NOTE(review): this equation is built but never added to `network`
        # or passed to `solve`, so it has no effect on the query.
        out_constraint = MarabouCore.Equation(MarabouCore.Equation.GE)
        for i in range(out_vars1.shape[0]):
            out_constraint.addAddend(1, out_vars1[i])
            out_constraint.addAddend(-1, out_vars2[i])
        out_constraint.setScalar(0)

        # NOTE(review): no input bounds are set in this function, so neither
        # x1 nor x2 constrains the query that is solved here.
        options = Marabou.createOptions(verbosity=0)
        exit_code, vals, stats = network.solve(options=options)

        # Only an "unsat" answer is recorded as passing.
        results.append(exit_code == "unsat")

    return np.mean(results)


In [None]:
# Verification sweep: run the three checks on small subsets for every trained
# model, tabulate the rates, and plot each rate against lambda.
n_ood = 5        # number of test outliers checked per model
n_robust = 5     # number of test points checked for robustness
n_mono = 5       # number of point pairs checked for monotonicity
tau = 0.8        # confidence threshold passed to the checks
epsilon = 0.5    # input perturbation radius for the robustness check

results = []

for lam, model in models.items():
    # Each check receives a freshly opened, lambda-specific text file handle.
    with open(f"ood_rejection_lambda_{lam}.txt", "w") as f:
        X_OOD_subset = test_outliers[:n_ood]
        ood_rejection_rate = verify_ood_rejection(model, X_OOD_subset, tau, f)

    with open(f"robustness_lambda_{lam}.txt", "w") as f:
        X_robust_subset = X_test[:n_robust]
        robustness_rate = verify_robustness(model, X_robust_subset, epsilon, tau, f)

    with open(f"monotonicity_lambda_{lam}.txt", "w") as f:
        # Pair the first n_mono test points with the next n_mono.
        X_mono_subset1 = X_test[:n_mono]
        X_mono_subset2 = X_test[n_mono:2*n_mono]
        monotonicity_rate = verify_monotonicity(model, X_mono_subset1, X_mono_subset2, X_train, f)

    results.append((lam, ood_rejection_rate, robustness_rate, monotonicity_rate))

# Text summary, one row per lambda.
table = PrettyTable()
table.field_names = ["Lambda", "OOD rejection rate", "Robustness rate", "Monotonicity rate"]
for row in results:
    table.add_row(row)

print("Results:")
print(table)

# Split the result tuples into one list per column.
lambdas, ood_rejection_rates, robustness_rates, monotonicity_rates = (
    [row[col] for row in results] for col in range(4)
)

plt.figure(figsize=(24, 6))

# (values, marker, y-axis label) for the three side-by-side panels.
panels = [
    (ood_rejection_rates, 'o', 'OOD rejection rate'),
    (robustness_rates, 's', 'Robustness rate'),
    (monotonicity_rates, '^', 'Monotonicity rate'),
]
for position, (rates, marker, ylabel) in enumerate(panels, start=1):
    plt.subplot(1, 3, position)
    plt.plot(lambdas, rates, marker=marker)
    plt.xlabel('Lambda')
    plt.ylabel(ylabel)

plt.tight_layout()
plt.show()